diff --git a/build_support/generate_modelsmodule.py b/build_support/generate_modelsmodule.py index fa2af8442b..bf588d6b6e 100644 --- a/build_support/generate_modelsmodule.py +++ b/build_support/generate_modelsmodule.py @@ -27,6 +27,7 @@ """ import argparse +import itertools import os import sys from pathlib import Path @@ -109,6 +110,7 @@ def get_models_from_file(model_file): "public Node": "node", "public ClopathArchivingNode": "clopath", "public UrbanczikArchivingNode": "urbanczik", + "public EpropArchivingNode": "neuron", "typedef binary_neuron": "binary", "typedef rate_": "rate", } @@ -227,9 +229,7 @@ def generate_modelsmodule(): 1. the copyright header. 2. a list of generic NEST includes 3. the list of includes for the models to build into NEST - 4. some boilerplate function implementations needed to fulfill the - Module interface - 5. the list of model registration lines for the models to build + 4. the list of model registration lines for the models to build into NEST The code is enriched by structured C++ comments as to make @@ -246,7 +246,16 @@ def generate_modelsmodule(): modeldir.mkdir(parents=True, exist_ok=True) with open(modeldir / fname, "w") as file: file.write(copyright_header.replace("{{file_name}}", fname)) - file.write('\n#include "models.h"\n\n// Generated includes\n#include "config.h"\n') + file.write( + dedent( + """ + #include "models.h" + + // Generated includes + #include "config.h" + """ + ) + ) for model_type, guards_fnames in includes.items(): file.write(f"\n// {model_type.capitalize()} models\n") diff --git a/doc/htmldoc/examples/index.rst b/doc/htmldoc/examples/index.rst index a264c3f452..88a93aed4a 100644 --- a/doc/htmldoc/examples/index.rst +++ b/doc/htmldoc/examples/index.rst @@ -198,6 +198,16 @@ PyNEST examples * :doc:`../auto_examples/evaluate_tsodyks2_synapse` +.. grid:: 1 1 2 3 + + .. grid-item-card:: :doc:`../auto_examples/eprop_plasticity/index` + :img-top: ../static/img/pynest/eprop_supervised_classification_infrastructure.png + + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_sine-waves` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_handwriting` + * :doc:`/auto_examples/eprop_plasticity/eprop_supervised_regression_infinite-loop` + .. grid:: 1 1 2 3 @@ -332,6 +342,7 @@ PyNEST examples ../auto_examples/astrocytes/astrocyte_interaction ../auto_examples/astrocytes/astrocyte_small_network ../auto_examples/astrocytes/astrocyte_brunel + ../auto_examples/eprop_plasticity/index .. 
toctree:: :hidden: diff --git a/doc/htmldoc/static/img/pynest/eprop_supervised_classification_infrastructure.png b/doc/htmldoc/static/img/pynest/eprop_supervised_classification_infrastructure.png new file mode 100644 index 0000000000..acb536c69c Binary files /dev/null and b/doc/htmldoc/static/img/pynest/eprop_supervised_classification_infrastructure.png differ diff --git a/doc/htmldoc/whats_new/v3.7/index.rst b/doc/htmldoc/whats_new/v3.7/index.rst index 4f4446d5dc..89601e1d99 100644 --- a/doc/htmldoc/whats_new/v3.7/index.rst +++ b/doc/htmldoc/whats_new/v3.7/index.rst @@ -40,3 +40,25 @@ See examples using astrocyte models: See connectivity documentation: * :ref:`tripartite_connectivity` + + +E-prop plasticity in NEST +------------------------- + +Another new NEST feature is eligibility propagation (e-prop) [1]_, a local and +online learning algorithm for recurrent spiking neural networks (RSNNs) that +serves as a biologically plausible approximation to backpropagation through time +(BPTT). It relies on eligibility traces and neuron-specific learning signals to +compute gradients without the need for error propagation backward in time. This +approach aligns with the brain's learning mechanisms and offers a strong +candidate for efficient training of RSNNs in low-power neuromorphic hardware. + +For further information, see: + +* :doc:`/auto_examples/eprop_plasticity/index` +* :doc:`/models/index_e-prop plasticity` + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y diff --git a/libnestutil/block_vector.h b/libnestutil/block_vector.h index 9ab0b0e9d1..5648804bac 100644 --- a/libnestutil/block_vector.h +++ b/libnestutil/block_vector.h @@ -236,6 +236,14 @@ class BlockVector */ void push_back( const value_type_& value ); + /** + * @brief Move data to the end of the BlockVector. + * @param value Data to be moved to end of BlockVector. + * + * Moves given data to the element at the end of the BlockVector. + */ + void push_back( value_type_&& value ); + /** * Erases all the elements. 
*/ @@ -313,15 +321,17 @@ class BlockVector ///////////////////////////////////////////////////////////// template < typename value_type_ > -inline BlockVector< value_type_ >::BlockVector() - : blockmap_( std::vector< std::vector< value_type_ > >( 1, std::vector< value_type_ >( max_block_size ) ) ) +BlockVector< value_type_ >::BlockVector() + : blockmap_( + std::vector< std::vector< value_type_ > >( 1, std::move( std::vector< value_type_ >( max_block_size ) ) ) ) , finish_( begin() ) { } template < typename value_type_ > -inline BlockVector< value_type_ >::BlockVector( size_t n ) - : blockmap_( std::vector< std::vector< value_type_ > >( 1, std::vector< value_type_ >( max_block_size ) ) ) +BlockVector< value_type_ >::BlockVector( size_t n ) + : blockmap_( + std::vector< std::vector< value_type_ > >( 1, std::move( std::vector< value_type_ >( max_block_size ) ) ) ) , finish_( begin() ) { size_t num_blocks_needed = std::ceil( static_cast< double >( n ) / max_block_size ); @@ -394,7 +404,7 @@ BlockVector< value_type_ >::end() const } template < typename value_type_ > -inline void +void BlockVector< value_type_ >::push_back( const value_type_& value ) { // If this is the last element in the current block, add another block @@ -411,7 +421,24 @@ BlockVector< value_type_ >::push_back( const value_type_& value ) } template < typename value_type_ > -inline void +void +BlockVector< value_type_ >::push_back( value_type_&& value ) +{ + // If this is the last element in the current block, add another block + if ( finish_.block_it_ == finish_.current_block_end_ - 1 ) + { + // Need to get the current position here, then recreate the iterator after we extend the blockmap, + // because after the blockmap is changed the iterator becomes invalid. + const auto current_block = finish_.block_vector_it_ - finish_.block_vector_->blockmap_.begin(); + blockmap_.emplace_back( max_block_size ); + finish_.block_vector_it_ = finish_.block_vector_->blockmap_.begin() + current_block; + } + *finish_ = std::move( value ); + ++finish_; +} + +template < typename value_type_ > +void BlockVector< value_type_ >::clear() { for ( auto it = blockmap_.begin(); it != blockmap_.end(); ++it ) @@ -442,7 +469,7 @@ BlockVector< value_type_ >::size() const } template < typename value_type_ > -inline typename BlockVector< value_type_ >::iterator +typename BlockVector< value_type_ >::iterator BlockVector< value_type_ >::erase( const_iterator first, const_iterator last ) { assert( first.block_vector_ == this ); @@ -495,7 +522,7 @@ BlockVector< value_type_ >::erase( const_iterator first, const_iterator last ) } template < typename value_type_ > -inline void +void BlockVector< value_type_ >::print_blocks() const { std::cerr << "this: \t\t" << this << "\n"; diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt index 7892e2cfd7..4b861b13f3 100644 --- a/models/CMakeLists.txt +++ b/models/CMakeLists.txt @@ -26,6 +26,7 @@ set(models_sources rate_neuron_ipn.h rate_neuron_ipn_impl.h rate_neuron_opn.h rate_neuron_opn_impl.h rate_transformer_node.h rate_transformer_node_impl.h + weight_optimizer.h weight_optimizer.cpp ${MODELS_SOURCES_GENERATED} ) diff --git a/models/eprop_iaf_adapt_bsshslm_2020.cpp b/models/eprop_iaf_adapt_bsshslm_2020.cpp new file mode 100644 index 0000000000..58cb06b9e0 --- /dev/null +++ b/models/eprop_iaf_adapt_bsshslm_2020.cpp @@ -0,0 +1,513 @@ +/* + * eprop_iaf_adapt_bsshslm_2020.cpp + * + * This file is part of NEST. 
+ * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +// nest models +#include "eprop_iaf_adapt_bsshslm_2020.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf_adapt_bsshslm_2020( const std::string& name ) +{ + register_node_model< eprop_iaf_adapt_bsshslm_2020 >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf_adapt_bsshslm_2020 > eprop_iaf_adapt_bsshslm_2020::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf_adapt_bsshslm_2020 >::create() +{ + insert_( names::adaptation, &eprop_iaf_adapt_bsshslm_2020::get_adaptation_ ); + insert_( names::V_th_adapt, &eprop_iaf_adapt_bsshslm_2020::get_v_th_adapt_ ); + insert_( names::learning_signal, &eprop_iaf_adapt_bsshslm_2020::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf_adapt_bsshslm_2020::get_surrogate_gradient_ ); + insert_( names::V_m, &eprop_iaf_adapt_bsshslm_2020::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf_adapt_bsshslm_2020::Parameters_::Parameters_() + : adapt_beta_( 1.0 ) + , adapt_tau_( 10.0 ) + , C_m_( 250.0 ) + , c_reg_( 0.0 ) + , E_L_( -70.0 ) + , f_target_( 0.01 ) + , gamma_( 0.3 ) + , I_e_( 0.0 ) + , regular_spike_arrival_( true ) + , surrogate_gradient_function_( "piecewise_linear" ) + , t_ref_( 2.0 ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_th_( -55.0 - E_L_ ) +{ +} + +eprop_iaf_adapt_bsshslm_2020::State_::State_() + : adapt_( 0.0 ) + , v_th_adapt_( 15.0 ) + , learning_signal_( 0.0 ) + , r_( 0 ) + , surrogate_gradient_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_iaf_adapt_bsshslm_2020::Buffers_::Buffers_( eprop_iaf_adapt_bsshslm_2020& n ) + : logger_( n ) +{ +} + +eprop_iaf_adapt_bsshslm_2020::Buffers_::Buffers_( const Buffers_&, eprop_iaf_adapt_bsshslm_2020& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt_bsshslm_2020::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::adapt_beta, adapt_beta_ ); + def< double >( d, names::adapt_tau, adapt_tau_ ); + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::f_target, 
f_target_ ); + def< double >( d, names::gamma, gamma_ ); + def< double >( d, names::I_e, I_e_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); +} + +double +eprop_iaf_adapt_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? E_L_ : delta_EL; + + updateValueParam< double >( d, names::adapt_beta, adapt_beta_, node ); + updateValueParam< double >( d, names::adapt_tau, adapt_tau_, node ); + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + + if ( adapt_beta_ < 0 ) + { + throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." ); + } + + if ( adapt_tau_ <= 0 ) + { + throw BadProperty( "Threshold adaptation time constant adapt_tau > 0 required." ); + } + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization prefactor c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( gamma_ < 0.0 or 1.0 <= gamma_ ) + { + throw BadProperty( "Surrogate gradient / pseudo-derivative scaling gamma from interval [0,1) required." ); + } + + if ( surrogate_gradient_function_ != "piecewise_linear" ) + { + throw BadProperty( + "Surrogate gradient / pseudo derivate function surrogate_gradient_function from [\"piecewise_linear\"] " + "required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( surrogate_gradient_function_ == "piecewise_linear" and fabs( V_th_ ) < 1e-6 ) + { + throw BadProperty( + "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." ); + } + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." 
); + } + + return delta_EL; +} + +void +eprop_iaf_adapt_bsshslm_2020::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::adaptation, adapt_ ); + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::V_th_adapt, v_th_adapt_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf_adapt_bsshslm_2020::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? p.E_L_ : delta_EL; + + // adaptive threshold can only be set indirectly via the adaptation variable + if ( updateValueParam< double >( d, names::adaptation, adapt_, node ) ) + { + // if E_L changed in this SetStatus call, p.V_th_ has been adjusted and no further action is needed + v_th_adapt_ = p.V_th_ + p.adapt_beta_ * adapt_; + } + else + { + // adjust voltage to change in E_L + v_th_adapt_ -= delta_EL; + } +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf_adapt_bsshslm_2020::eprop_iaf_adapt_bsshslm_2020() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf_adapt_bsshslm_2020::eprop_iaf_adapt_bsshslm_2020( const eprop_iaf_adapt_bsshslm_2020& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt_bsshslm_2020::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf_adapt_bsshslm_2020::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + + if ( P_.surrogate_gradient_function_ == "piecewise_linear" ) + { + compute_surrogate_gradient = &eprop_iaf_adapt_bsshslm_2020::compute_piecewise_linear_derivative; + } + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called alpha in reference [1]_ + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; + V_.P_adapt_ = std::exp( -dt / P_.adapt_tau_ ); +} + +long +eprop_iaf_adapt_bsshslm_2020::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_adapt_bsshslm_2020::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt_bsshslm_2020::update( Time const& origin, const long from, const long to ) +{ + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const bool with_reset = kernel().simulation_manager.get_eprop_reset_neurons_on_update(); + const long shift = get_shift(); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + const long interval_step = ( t - shift ) % update_interval; + + if ( interval_step == 0 ) + { + erase_used_firing_rate_reg_history(); + erase_used_update_history(); + erase_used_eprop_history(); + + if ( with_reset ) + { + S_.v_m_ = 0.0; + S_.adapt_ = 0.0; + S_.r_ = 0; + S_.z_ = 0.0; + } + } + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ -= P_.V_th_ * S_.z_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + S_.adapt_ = V_.P_adapt_ * S_.adapt_ + S_.z_; + S_.v_th_adapt_ = P_.V_th_ + P_.adapt_beta_ * S_.adapt_; + + S_.z_ = 0.0; + + S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient )(); + + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + + if ( S_.v_m_ >= S_.v_th_adapt_ and S_.r_ == 0 ) + { + count_spike(); + + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + S_.z_ = 1.0; + + if ( V_.RefractoryCounts_ > 0 ) + { + S_.r_ = V_.RefractoryCounts_; + } + } + + if ( interval_step == update_interval - 1 ) + { + write_firing_rate_reg_to_history( t, P_.f_target_, P_.c_reg_ ); + reset_spike_count(); + } + + S_.learning_signal_ = get_learning_signal_from_history( t ); + + if ( S_.r_ > 0 ) + { + --S_.r_; + } + + S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Surrogate gradient functions + * ---------------------------------------------------------------- */ + +double +eprop_iaf_adapt_bsshslm_2020::compute_piecewise_linear_derivative() +{ + if ( S_.r_ > 0 ) + { + return 0.0; + } + + return P_.gamma_ * std::max( 0.0, 1.0 - std::fabs( ( S_.v_m_ - S_.v_th_adapt_ ) / P_.V_th_ ) ) / P_.V_th_; +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_adapt_bsshslm_2020::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf_adapt_bsshslm_2020::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf_adapt_bsshslm_2020::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double 
error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal ); + } +} + +void +eprop_iaf_adapt_bsshslm_2020::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +double +eprop_iaf_adapt_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) +{ + auto eprop_hist_it = get_eprop_history( t_previous_trigger_spike ); + + double e = 0.0; // eligibility trace + double e_bar = 0.0; // low-pass filtered eligibility trace + double epsilon = 0.0; // adaptive component of eligibility vector + double grad = 0.0; // gradient value to be calculated + double L = 0.0; // learning signal + double psi = 0.0; // surrogate gradient + double sum_e = 0.0; // sum of eligibility traces + double z = 0.0; // spiking variable + double z_bar = 0.0; // low-pass filtered spiking variable + + for ( long presyn_isi : presyn_isis ) + { + z = 1.0; // set spiking variable to 1 for each incoming spike + + for ( long t = 0; t < presyn_isi; ++t ) + { + assert( eprop_hist_it != eprop_history_.end() ); + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + e = psi * ( z_bar - P_.adapt_beta_ * epsilon ); + epsilon = psi * z_bar + ( V_.P_adapt_ - psi * P_.adapt_beta_ ) * epsilon; + e_bar = kappa * e_bar + ( 1.0 - kappa ) * e; + grad += L * e_bar; + sum_e += e; + z = 0.0; // set spiking variable to 0 between spikes + + ++eprop_hist_it; + } + } + presyn_isis.clear(); + + const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + if ( average_gradient ) + { + grad /= learning_window; + } + + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const auto it_reg_hist = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); + grad += it_reg_hist->firing_rate_reg_ * sum_e; + + return grad; +} + +} // namespace nest diff --git a/models/eprop_iaf_adapt_bsshslm_2020.h b/models/eprop_iaf_adapt_bsshslm_2020.h new file mode 100644 index 0000000000..f67f6673ff --- /dev/null +++ b/models/eprop_iaf_adapt_bsshslm_2020.h @@ -0,0 +1,586 @@ +/* + * eprop_iaf_adapt_bsshslm_2020.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +#ifndef EPROP_IAF_ADAPT_BSSHSLM_2020_H +#define EPROP_IAF_ADAPT_BSSHSLM_2020_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire, adaptive threshold + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents and threshold adaptation for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf_adapt_bsshslm_2020`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents and threshold adaptation +used for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the +model name the paper that introduced it by the first letter of the authors' last +names and the publication year. + + .. note:: + The neuron dynamics of the ``eprop_iaf_adapt_bsshslm_2020`` model (excluding + e-prop plasticity and the threshold adaptation) are similar to the neuron + dynamics of the ``iaf_psc_delta`` model, with minor differences, such as the + propagator of the post-synaptic current and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. math:: + v_j^t &= \alpha v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{rec}z_i^{t-1} + + \sum_i W_{ji}^\mathrm{in}x_i^t-z_j^{t-1}v_\mathrm{th} \,, \\ + \alpha &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,, + +whereby :math:`W_{ji}^\mathrm{rec}` and :math:`W_{ji}^\mathrm{in}` are the recurrent and +input synaptic weights, and :math:`z_i^{t-1}` and :math:`x_i^t` are the +recurrent and input presynaptic spike state variables, respectively. + +Descriptions of further parameters and variables can be found in the table below. + +The threshold adaptation is given by: + +.. math:: + A_j^t &= v_\mathrm{th} + \beta a_j^t \,, \\ + a_j^t &= \rho a_j^{t-1} + z_j^{t-1} \,, \\ + \rho &= e^{-\frac{\Delta t}{\tau_\mathrm{a}}} \,. + +The spike state variable is expressed by a Heaviside function: + +.. math:: + z_j^t = H\left(v_j^t-A_j^t\right) \,. + +If the membrane voltage crosses the adaptive threshold voltage :math:`A_j^t`, a spike is +emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next +time step. After the time step of the spike emission, the neuron is not +able to spike for an absolute refractory period :math:`t_\text{ref}`. + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +Furthermore, the pseudo derivative of the membrane voltage needed for e-prop +plasticity is calculated: + +.. math:: + \psi_j^t = \frac{\gamma}{v_\text{th}} \text{max} + \left(0, 1-\left| \frac{v_j^t-A_j^t}{v_\text{th}}\right| \right) \,. + +See the documentation on the ``iaf_psc_delta`` neuron model for more information +on the integration of the subthreshold dynamics. 
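+
+As a reading aid, the following is a minimal, self-contained Python sketch of one update step of
+these dynamics. It is not part of the NEST sources; the function name ``adaptive_lif_step`` and the
+chosen default values are purely illustrative, and it assumes the default piecewise-linear surrogate
+gradient together with ``regular_spike_arrival = True``.
+
+.. code-block:: python
+
+    import numpy as np
+
+    def adaptive_lif_step(v, a, z, z_in, i_in,
+                          dt=1.0, tau_m=10.0, tau_a=10.0, C_m=250.0,
+                          v_th=15.0, beta=1.0, gamma=0.3):
+        """One update step; all voltages are relative to E_L, as in the C++ code.
+
+        Refractoriness and the lower bound V_min are omitted for brevity.
+        """
+        alpha = np.exp(-dt / tau_m)   # membrane propagator (alpha above)
+        rho = np.exp(-dt / tau_a)     # adaptation propagator (rho above)
+
+        # membrane voltage update with reset by the previous spike z
+        v = alpha * v + (tau_m / C_m) * (1.0 - alpha) * i_in + z_in - v_th * z
+        a = rho * a + z               # adaptation variable a_j^t
+        A = v_th + beta * a           # adaptive threshold A_j^t
+
+        # piecewise-linear pseudo-derivative psi_j^t
+        psi = gamma * max(0.0, 1.0 - abs((v - A) / v_th)) / v_th
+
+        z = 1.0 if v >= A else 0.0    # Heaviside spike condition
+        return v, a, A, psi, z
+
+Iterating this sketch over an input spike train should, up to the omitted refractory handling,
+follow the same state trajectory that the model's ``update()`` routine writes to the e-prop history.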
+ +The change of the synaptic weight is calculated from the gradient :math:`g` of +the loss :math:`E` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ij}}}=g` +which depends on the presynaptic +spikes :math:`z_i^{t-1}`, the surrogate-gradient / pseudo-derivative of the postsynaptic membrane +voltage :math:`\psi_j^t` (which together form the eligibility trace +:math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted by the readout +neurons. + +.. math:: + \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g &= \sum_t L_j^t \bar{e}_{ji}^t, \\ + e_{ji}^t &= \psi_j^t \left(\bar{z}_i^{t-1} - \beta \epsilon_{ji,a}^{t-1}\right)\,, \\ + \epsilon^{t-1}_{ji,\text{a}} &= \psi_j^{t-1}\bar{z}_i^{t-2} + \left( \rho - \psi_j^{t-1} \beta \right) + \epsilon^{t-2}_{ji,a}\,. \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with some exponential kernels: + +.. math:: + \bar{e}_{ji}^t&=\mathcal{F}_\kappa(e_{ji}^t) \;\text{with}\, \kappa=e^{-\frac{\Delta t}{ + \tau_\text{m,out}}}\,,\\ + \bar{z}_i^t&=\mathcal{F}_\alpha(z_i^t)\,,\\ + \mathcal{F}_\alpha(z_i^t) &= \alpha\, \mathcal{F}_\alpha(z_i^{t-1}) + z_i^t + \;\text{with}\, \mathcal{F}_\alpha(z_i^0)=z_i^0\,\,, + +whereby :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. + +Furthermore, a firing rate regularization mechanism keeps the average firing +rate :math:`f^\text{av}_j` of the postsynaptic neuron close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g^\text{reg}` of the regularization loss :math:`E^\text{reg}` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. math:: + \frac{\mathrm{d}E^\text{reg}}{\mathrm{d}W_{ji}} = g^\text{reg} = c_\text{reg} + \sum_t \frac{1}{Tn_\text{trial}} \left( f^\text{target}-f^\text{av}_j\right)e_{ji}^t\,, + +whereby :math:`c_\text{reg}` scales the overall regularization and the average +is taken over the time that passed since the previous update, that is, the number of +trials :math:`n_\text{trial}` times the duration of an update interval :math:`T`. + +The overall gradient is given by the addition of the two gradients. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_bsshslm_2020<../models/eprop_iaf_bsshslm_2020/>` + * :doc:`eprop_readout_bsshslm_2020<../models/eprop_readout_bsshslm_2020/>` + * :doc:`eprop_synapse_bsshslm_2020<../models/eprop_synapse_bsshslm_2020/>` + * :doc:`eprop_learning_signal_connection_bsshslm_2020<../models/eprop_learning_signal_connection_bsshslm_2020/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
+ +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +adapt_beta :math:`\beta` 1.0 Prefactor of the threshold + adaptation +adapt_tau ms :math:`\tau_\text{a}` 10.0 Time constant of the threshold + adaptation +C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +c_reg :math:`c_\text{reg}` 0.0 Prefactor of firing rate + regularization +E_L mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +f_target Hz :math:`f^\text{target}` 10.0 Target firing rate of rate + regularization +gamma :math:`\gamma` 0.3 Scaling of surrogate gradient / + pseudo-derivative of + membrane voltage +I_e pA :math:`I_\text{e}` 0.0 Constant external input current +regular_spike_arrival Boolean True If True, the input spikes arrive at + the end of the time step, if False + at the beginning (determines PSC + scale) +surrogate_gradient_function :math:`\psi` piecewise_linear Surrogate gradient / + pseudo-derivative function + ["piecewise_linear"] +t_ref ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the + membrane voltage +V_th mV :math:`v_\text{th}` -55.0 Spike threshold voltage +=========================== ======= ======================= ================ =================================== + +The following state variables evolve during simulation. + +================== ==== =============== ============= ======================== +**Neuron state variables and recordables** +------------------------------------------------------------------------------ +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ======================== +adaptation :math:`a_j` 0.0 Adaptation variable +learning_signal :math:`L_j` 0.0 Learning signal +surrogate_gradient :math:`\psi_j` 0.0 Surrogate gradient +V_m mV :math:`v_j` -70.0 Membrane voltage +V_th_adapt mV :math:`A_j` -55.0 Adapting spike threshold +================== ==== =============== ============= ======================== + +Recordables ++++++++++++ + +The following variables can be recorded: + + - adaptation variable ``adaptation`` + - adapting spike threshold ``V_th_adapt`` + - learning signal ``learning_signal`` + - membrane potential ``V_m`` + - surrogate gradient ``surrogate_gradient`` + +Usage ++++++ + +This model can only be used in combination with the other e-prop models, +whereby the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. 
Event-based implementation of + eligibility propagation (in preparation) + +Sends +++++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. listexamples:: eprop_iaf_adapt_bsshslm_2020 + +EndUserDocs */ + +void register_eprop_iaf_adapt_bsshslm_2020( const std::string& name ); + +/** + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents and + * threshold adaptation for e-prop plasticity according to Bellec et al (2020). + */ +class eprop_iaf_adapt_bsshslm_2020 : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf_adapt_bsshslm_2020(); + + //! Copy constructor. + eprop_iaf_adapt_bsshslm_2020( const eprop_iaf_adapt_bsshslm_2020& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + + double compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) override; + + void pre_run_hook() override; + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + void update( Time const&, const long, const long ) override; + +protected: + void init_buffers_() override; + +private: + //! Compute the piecewise linear surrogate gradient. + double compute_piecewise_linear_derivative(); + + //! Compute the surrogate gradient. + double ( eprop_iaf_adapt_bsshslm_2020::*compute_surrogate_gradient )(); + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf_adapt_bsshslm_2020 >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf_adapt_bsshslm_2020 >; + + //! Structure of parameters. + struct Parameters_ + { + //! Prefactor of the threshold adaptation. + double adapt_beta_; + + //! Time constant of the threshold adaptation (ms). + double adapt_tau_; + + //! Capacitance of the membrane (pF). + double C_m_; + + //! Prefactor of firing rate regularization. + double c_reg_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Scaling of surrogate-gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! Constant external input current (pA). + double I_e_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Surrogate gradient / pseudo-derivative function ["piecewise_linear"]. + std::string surrogate_gradient_function_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! 
Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). + double V_th_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Adaptation variable. + double adapt_; + + //! Adapting spike threshold voltage. + double v_th_adapt_; + + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Number of remaining refractory steps. + int r_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Binary spike variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_; + + //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf_adapt_bsshslm_2020& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf_adapt_bsshslm_2020& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf_adapt_bsshslm_2020 > logger_; + }; + + //! Structure of general variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage. + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike variables. + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Propagator matrix entry for evolving the adaptation. + double P_adapt_; + + //! Total refractory steps. + int RefractoryCounts_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + //! Get the current value of the adapting threshold. + double + get_v_th_adapt_() const + { + return S_.v_th_adapt_ + P_.E_L_; + } + + //! Get the current value of the adaptation. + double + get_adaptation_() const + { + return S_.adapt_; + } + + // the order in which the structure instances are defined is important for speed + + //!< Structure of parameters. + Parameters_ P_; + + //!< Structure of state variables. + State_ S_; + + //!< Structure of general variables. + Variables_ V_; + + //!< Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_iaf_adapt_bsshslm_2020 > recordablesMap_; +}; + +inline size_t +eprop_iaf_adapt_bsshslm_2020::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf_adapt_bsshslm_2020::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt_bsshslm_2020::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt_bsshslm_2020::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_adapt_bsshslm_2020::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf_adapt_bsshslm_2020::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf_adapt_bsshslm_2020::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_ADAPT_BSSHSLM_2020_H diff --git a/models/eprop_iaf_bsshslm_2020.cpp b/models/eprop_iaf_bsshslm_2020.cpp new file mode 100644 index 0000000000..108ea1e71a --- /dev/null +++ b/models/eprop_iaf_bsshslm_2020.cpp @@ -0,0 +1,472 @@ +/* + * eprop_iaf_bsshslm_2020.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_iaf_bsshslm_2020.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_iaf_bsshslm_2020( const std::string& name ) +{ + register_node_model< eprop_iaf_bsshslm_2020 >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_iaf_bsshslm_2020 > eprop_iaf_bsshslm_2020::recordablesMap_; + +template <> +void +RecordablesMap< eprop_iaf_bsshslm_2020 >::create() +{ + insert_( names::learning_signal, &eprop_iaf_bsshslm_2020::get_learning_signal_ ); + insert_( names::surrogate_gradient, &eprop_iaf_bsshslm_2020::get_surrogate_gradient_ ); + insert_( names::V_m, &eprop_iaf_bsshslm_2020::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_iaf_bsshslm_2020::Parameters_::Parameters_() + : C_m_( 250.0 ) + , c_reg_( 0.0 ) + , E_L_( -70.0 ) + , f_target_( 0.01 ) + , gamma_( 0.3 ) + , I_e_( 0.0 ) + , regular_spike_arrival_( true ) + , surrogate_gradient_function_( "piecewise_linear" ) + , t_ref_( 2.0 ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) + , V_th_( -55.0 - E_L_ ) +{ +} + +eprop_iaf_bsshslm_2020::State_::State_() + : learning_signal_( 0.0 ) + , r_( 0 ) + , surrogate_gradient_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_iaf_bsshslm_2020::Buffers_::Buffers_( eprop_iaf_bsshslm_2020& n ) + : logger_( n ) +{ +} + +eprop_iaf_bsshslm_2020::Buffers_::Buffers_( const Buffers_&, eprop_iaf_bsshslm_2020& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_iaf_bsshslm_2020::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::c_reg, c_reg_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::f_target, f_target_ ); + def< double >( d, names::gamma, gamma_ ); + def< double >( d, names::I_e, I_e_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_ ); + def< double >( d, names::t_ref, t_ref_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); + def< double >( d, names::V_th, V_th_ + E_L_ ); +} + +double +eprop_iaf_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_th_ -= updateValueParam< double >( d, names::V_th, V_th_, node ) ? E_L_ : delta_EL; + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? 
E_L_ : delta_EL; + + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::c_reg, c_reg_, node ); + + if ( updateValueParam< double >( d, names::f_target, f_target_, node ) ) + { + f_target_ /= 1000.0; // convert from spikes/s to spikes/ms + } + + updateValueParam< double >( d, names::gamma, gamma_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< std::string >( d, names::surrogate_gradient_function, surrogate_gradient_function_, node ); + updateValueParam< double >( d, names::t_ref, t_ref_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( c_reg_ < 0 ) + { + throw BadProperty( "Firing rate regularization prefactor c_reg ≥ 0 required." ); + } + + if ( f_target_ < 0 ) + { + throw BadProperty( "Firing rate regularization target rate f_target ≥ 0 required." ); + } + + if ( gamma_ < 0.0 or 1.0 <= gamma_ ) + { + throw BadProperty( "Surrogate gradient / pseudo-derivative scaling gamma from interval [0,1) required." ); + } + + if ( surrogate_gradient_function_ != "piecewise_linear" ) + { + throw BadProperty( + "Surrogate gradient / pseudo derivate function surrogate_gradient_function from [\"piecewise_linear\"] " + "required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + if ( t_ref_ < 0 ) + { + throw BadProperty( "Refractory time t_ref ≥ 0 required." ); + } + + if ( surrogate_gradient_function_ == "piecewise_linear" and fabs( V_th_ ) < 1e-6 ) + { + throw BadProperty( + "Relative threshold voltage V_th-E_L ≠ 0 required if surrogate_gradient_function is \"piecewise_linear\"." ); + } + + if ( V_th_ < V_min_ ) + { + throw BadProperty( "Spike threshold voltage V_th ≥ minimal voltage V_min required." ); + } + + return delta_EL; +} + +void +eprop_iaf_bsshslm_2020::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); + def< double >( d, names::learning_signal, learning_signal_ ); +} + +void +eprop_iaf_bsshslm_2020::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? 
p.E_L_ : delta_EL; +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_iaf_bsshslm_2020::eprop_iaf_bsshslm_2020() + : EpropArchivingNodeRecurrent() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_iaf_bsshslm_2020::eprop_iaf_bsshslm_2020( const eprop_iaf_bsshslm_2020& n ) + : EpropArchivingNodeRecurrent( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_bsshslm_2020::init_buffers_() +{ + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_iaf_bsshslm_2020::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + V_.RefractoryCounts_ = Time( Time::ms( P_.t_ref_ ) ).get_steps(); + + if ( P_.surrogate_gradient_function_ == "piecewise_linear" ) + { + compute_surrogate_gradient = &eprop_iaf_bsshslm_2020::compute_piecewise_linear_derivative; + } + + // calculate the entries of the propagator matrix for the evolution of the state vector + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called alpha in reference [1] + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 1.0 : 1.0 - V_.P_v_m_; +} + +long +eprop_iaf_bsshslm_2020::get_shift() const +{ + return offset_gen_ + delay_in_rec_; +} + +bool +eprop_iaf_bsshslm_2020::is_eprop_recurrent_node() const +{ + return true; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_iaf_bsshslm_2020::update( Time const& origin, const long from, const long to ) +{ + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const bool with_reset = kernel().simulation_manager.get_eprop_reset_neurons_on_update(); + const long shift = get_shift(); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + const long interval_step = ( t - shift ) % update_interval; + + if ( interval_step == 0 ) + { + erase_used_firing_rate_reg_history(); + erase_used_update_history(); + erase_used_eprop_history(); + + if ( with_reset ) + { + S_.v_m_ = 0.0; + S_.r_ = 0; + S_.z_ = 0.0; + } + } + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ -= P_.V_th_ * S_.z_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + S_.z_ = 0.0; + + S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient )(); + + write_surrogate_gradient_to_history( t, S_.surrogate_gradient_ ); + + if ( S_.v_m_ >= P_.V_th_ and S_.r_ == 0 ) + { + count_spike(); + + SpikeEvent se; + kernel().event_delivery_manager.send( *this, se, lag ); + + S_.z_ = 1.0; + + if ( V_.RefractoryCounts_ > 0 ) + { + S_.r_ = V_.RefractoryCounts_; + } + } + + if ( interval_step == update_interval - 1 ) + { + write_firing_rate_reg_to_history( t, P_.f_target_, P_.c_reg_ ); + reset_spike_count(); + } + + S_.learning_signal_ = get_learning_signal_from_history( t ); + + if ( S_.r_ > 0 ) + { + --S_.r_; + } + + S_.i_in_ = B_.currents_.get_value( lag 
) + P_.I_e_; + + B_.logger_.record_data( t ); + } +} + +/* ---------------------------------------------------------------- + * Surrogate gradient functions + * ---------------------------------------------------------------- */ + +double +eprop_iaf_bsshslm_2020::compute_piecewise_linear_derivative() +{ + if ( S_.r_ > 0 ) + { + return 0.0; + } + + return P_.gamma_ * std::max( 0.0, 1.0 - std::fabs( ( S_.v_m_ - P_.V_th_ ) / P_.V_th_ ) ) / P_.V_th_; +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_iaf_bsshslm_2020::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_iaf_bsshslm_2020::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_iaf_bsshslm_2020::handle( LearningSignalConnectionEvent& e ) +{ + for ( auto it_event = e.begin(); it_event != e.end(); ) + { + const long time_step = e.get_stamp().get_steps(); + const double weight = e.get_weight(); + const double error_signal = e.get_coeffvalue( it_event ); // get_coeffvalue advances iterator + const double learning_signal = weight * error_signal; + + write_learning_signal_to_history( time_step, learning_signal ); + } +} + +void +eprop_iaf_bsshslm_2020::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +double +eprop_iaf_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) +{ + auto eprop_hist_it = get_eprop_history( t_previous_trigger_spike ); + + double e = 0.0; // eligibility trace + double e_bar = 0.0; // low-pass filtered eligibility trace + double grad = 0.0; // gradient value to be calculated + double L = 0.0; // learning signal + double psi = 0.0; // surrogate gradient + double sum_e = 0.0; // sum of eligibility traces + double z = 0.0; // spiking variable + double z_bar = 0.0; // low-pass filtered spiking variable + + for ( long presyn_isi : presyn_isis ) + { + z = 1.0; // set spiking variable to 1 for each incoming spike + + for ( long t = 0; t < presyn_isi; ++t ) + { + assert( eprop_hist_it != eprop_history_.end() ); + + psi = eprop_hist_it->surrogate_gradient_; + L = eprop_hist_it->learning_signal_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + e = psi * z_bar; + e_bar = kappa * e_bar + ( 1.0 - kappa ) * e; + grad += L * e_bar; + sum_e += e; + z = 0.0; // set spiking variable to 0 between spikes + + ++eprop_hist_it; + } + } + presyn_isis.clear(); + + const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + if ( average_gradient ) + { + grad /= learning_window; + } + + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const auto it_reg_hist = get_firing_rate_reg_history( t_previous_update + get_shift() + update_interval ); + grad += it_reg_hist->firing_rate_reg_ * sum_e; + + return grad; +} + +} // namespace nest diff --git a/models/eprop_iaf_bsshslm_2020.h b/models/eprop_iaf_bsshslm_2020.h new file mode 100644 index 0000000000..2a7f2d96b1 --- /dev/null +++ b/models/eprop_iaf_bsshslm_2020.h @@ -0,0 +1,540 
@@ +/* + * eprop_iaf_bsshslm_2020.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_IAF_BSSHSLM_2020_H +#define EPROP_IAF_BSSHSLM_2020_H + +// nestkernel +#include "connection.h" +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "event.h" +#include "nest_types.h" +#include "ring_buffer.h" +#include "universal_data_logger.h" + +namespace nest +{ + +/* BeginUserDocs: neuron, e-prop plasticity, current-based, integrate-and-fire + +Short description ++++++++++++++++++ + +Current-based leaky integrate-and-fire neuron model with delta-shaped +postsynaptic currents for e-prop plasticity + +Description ++++++++++++ + +``eprop_iaf_bsshslm_2020`` is an implementation of a leaky integrate-and-fire +neuron model with delta-shaped postsynaptic currents used for eligibility +propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the +model name the paper that introduced it by the first letter of the authors' last +names and the publication year. + +.. note:: + The neuron dynamics of the ``eprop_iaf_bsshslm_2020`` model (excluding e-prop + plasticity) are similar to the neuron dynamics of the ``iaf_psc_delta`` model, + with minor differences, such as the propagator of the post-synaptic current + and the voltage reset upon a spike. + +The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by: + +.. math:: + v_j^t &= \alpha v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{rec}z_i^{t-1} + + \sum_i W_{ji}^\mathrm{in}x_i^t-z_j^{t-1}v_\mathrm{th} \,, \\ + \alpha &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,, + +whereby :math:`W_{ji}^\mathrm{rec}` and :math:`W_{ji}^\mathrm{in}` are the recurrent and +input synaptic weights, and :math:`z_i^{t-1}` and :math:`x_i^t` are the +recurrent and input presynaptic spike state variables, respectively. + +Descriptions of further parameters and variables can be found in the table below. + +The spike state variable is expressed by a Heaviside function: + +.. math:: + z_j^t = H\left(v_j^t-v_\mathrm{th}\right) \,. + +If the membrane voltage crosses the threshold voltage :math:`v_\text{th}`, a spike is +emitted and the membrane voltage is reduced by :math:`v_\text{th}` in the next +time step. After the time step of the spike emission, the neuron is not +able to spike for an absolute refractory period :math:`t_\text{ref}`. + +An additional state variable and the corresponding differential equation +represents a piecewise constant external current. + +Furthermore, the pseudo derivative of the membrane voltage needed for e-prop +plasticity is calculated: + +.. math:: + \psi_j^t = \frac{\gamma}{v_\text{th}} \text{max} + \left(0, 1-\left| \frac{v_j^t-v_\mathrm{th}}{v_\text{th}}\right| \right) \,. 
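+
+For illustration, a hedged Python sketch of this pseudo-derivative is given below. The function name
+is illustrative and not part of NEST; the model's C++ implementation additionally returns zero while
+the neuron is refractory.
+
+.. code-block:: python
+
+    def piecewise_linear_pseudo_derivative(v_m, v_th=15.0, gamma=0.3):
+        """Surrogate gradient psi_j^t for a membrane voltage given relative to E_L."""
+        return gamma * max(0.0, 1.0 - abs((v_m - v_th) / v_th)) / v_th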
+ +See the documentation on the ``iaf_psc_delta`` neuron model for more information +on the integration of the subthreshold dynamics. + +The change of the synaptic weight is calculated from the gradient :math:`g` of +the loss :math:`E` with respect to the synaptic weight :math:`W_{ji}`: +:math:`\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ij}}}=g` +which depends on the presynaptic +spikes :math:`z_i^{t-1}`, the surrogate-gradient / pseudo-derivative of the postsynaptic membrane +voltage :math:`\psi_j^t` (which together form the eligibility trace +:math:`e_{ji}^t`), and the learning signal :math:`L_j^t` emitted by the readout +neurons. + +.. math:: + \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g &= \sum_t L_j^t \bar{e}_{ji}^t, \\ + e_{ji}^t &= \psi^t_j \bar{z}_i^{t-1}\,, \\ + +The eligibility trace and the presynaptic spike trains are low-pass filtered +with some exponential kernels: + +.. math:: + \bar{e}_{ji}^t &= \mathcal{F}_\kappa(e_{ji}^t) \;\text{with}\, \kappa=e^{-\frac{\Delta t}{ + \tau_\text{m,out}}}\,,\\ + \bar{z}_i^t&=\mathcal{F}_\alpha(z_i^t)\,,\\ + \mathcal{F}_\alpha(z_i^t) &= \alpha\, \mathcal{F}_\alpha(z_i^{t-1}) + z_i^t + \;\text{with}\, \mathcal{F}_\alpha(z_i^0)=z_i^0\,, + +whereby :math:`\tau_\text{m,out}` is the membrane time constant of the readout neuron. + +Furthermore, a firing rate regularization mechanism keeps the average firing +rate :math:`f^\text{av}_j` of the postsynaptic neuron close to a target firing rate +:math:`f^\text{target}`. The gradient :math:`g^\text{reg}` of the regularization loss :math:`E^\text{reg}` +with respect to the synaptic weight :math:`W_{ji}` is given by: + +.. math:: + \frac{\mathrm{d}E^\text{reg}}{\mathrm{d}W_{ji}} = g^\text{reg} = c_\text{reg} + \sum_t \frac{1}{Tn_\text{trial}} \left( f^\text{target}-f^\text{av}_j\right)e_{ji}^t\,, + +whereby :math:`c_\text{reg}` scales the overall regularization and the average +is taken over the time that passed since the previous update, that is, the number of +trials :math:`n_\text{trial}` times the duration of an update interval :math:`T`. + +The overall gradient is given by the addition of the two gradients. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_adapt_bsshslm_2020<../models/eprop_iaf_adapt_bsshslm_2020/>` + * :doc:`eprop_readout_bsshslm_2020<../models/eprop_readout_bsshslm_2020/>` + * :doc:`eprop_synapse_bsshslm_2020<../models/eprop_synapse_bsshslm_2020/>` + * :doc:`eprop_learning_signal_connection_bsshslm_2020<../models/eprop_learning_signal_connection_bsshslm_2020/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. 
+ +=========================== ======= ======================= ================ =================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +=========================== ======= ======================= ================ =================================== +C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +c_reg :math:`c_\text{reg}` 0.0 Prefactor of firing rate + regularization +E_L mV :math:`E_\text{L}` -70.0 Leak / resting membrane potential +f_target Hz :math:`f^\text{target}` 10.0 Target firing rate of rate + regularization +gamma :math:`\gamma` 0.3 Scaling of surrogate gradient / + pseudo-derivative of membrane + voltage +I_e pA :math:`I_\text{e}` 0.0 Constant external input current +regular_spike_arrival Boolean True If True, the input spikes arrive at + the end of the time step, if + False at the beginning (determines + PSC scale) +surrogate_gradient_function :math:`\psi` piecewise_linear Surrogate gradient / + pseudo-derivative function + ["piecewise_linear"] +t_ref ms :math:`t_\text{ref}` 2.0 Duration of the refractory period +tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the + membrane voltage +V_th mV :math:`v_\text{th}` -55.0 Spike threshold voltage +=========================== ======= ======================= ================ =================================== + +The following state variables evolve during simulation. + +================== ==== =============== ============= ========================================================== +**Neuron state variables and recordables** +---------------------------------------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +================== ==== =============== ============= ========================================================== +learning_signal pA :math:`L_j` 0.0 Learning signal +surrogate_gradient :math:`\psi_j` 0.0 Surrogate gradient / pseudo-derivative of membrane voltage +V_m mV :math:`v_j` -70.0 Membrane voltage +================== ==== =============== ============= ========================================================== + +Recordables ++++++++++++ + +The following variables can be recorded: + + - learning signal ``learning_signal`` + - membrane potential ``V_m`` + - surrogate gradient ``surrogate_gradient`` + +Usage ++++++ + +This model can only be used in combination with the other e-prop models, +whereby the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. 
Event-based implementation of + eligibility propagation (in preparation) + +Sends +++++++++ + +SpikeEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, LearningSignalConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. listexamples:: eprop_iaf_bsshslm_2020 + +EndUserDocs */ + +void register_eprop_iaf_bsshslm_2020( const std::string& name ); + +/** + * Class implementing a current-based leaky integrate-and-fire neuron model with delta-shaped postsynaptic currents for + * e-prop plasticity according to Bellec et al (2020). + */ +class eprop_iaf_bsshslm_2020 : public EpropArchivingNodeRecurrent +{ + +public: + //! Default constructor. + eprop_iaf_bsshslm_2020(); + + //! Copy constructor. + eprop_iaf_bsshslm_2020( const eprop_iaf_bsshslm_2020& ); + + using Node::handle; + using Node::handles_test_event; + + size_t send_test_event( Node&, size_t, synindex, bool ) override; + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( LearningSignalConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( LearningSignalConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + + double compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) override; + + void pre_run_hook() override; + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + void update( Time const&, const long, const long ) override; + +protected: + void init_buffers_() override; + +private: + //! Compute the piecewise linear surrogate gradient. + double compute_piecewise_linear_derivative(); + + //! Compute the surrogate gradient. + double ( eprop_iaf_bsshslm_2020::*compute_surrogate_gradient )(); + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_iaf_bsshslm_2020 >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_iaf_bsshslm_2020 >; + + //! Structure of parameters. + struct Parameters_ + { + //! Capacitance of the membrane (pF). + double C_m_; + + //! Prefactor of firing rate regularization. + double c_reg_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Target firing rate of rate regularization (spikes/s). + double f_target_; + + //! Scaling of surrogate-gradient / pseudo-derivative of membrane voltage. + double gamma_; + + //! Constant external input current (pA). + double I_e_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Surrogate gradient / pseudo-derivative function ["piecewise_linear"]. + std::string surrogate_gradient_function_; + + //! Duration of the refractory period (ms). + double t_ref_; + + //! Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Spike threshold voltage relative to the leak membrane potential (mV). 
+ double V_th_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Learning signal. Sum of weighted error signals coming from the readout neurons. + double learning_signal_; + + //! Number of remaining refractory steps. + int r_; + + //! Surrogate gradient / pseudo-derivative of the membrane voltage. + double surrogate_gradient_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Binary spike variable - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_; + + //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_iaf_bsshslm_2020& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_iaf_bsshslm_2020& ); + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_iaf_bsshslm_2020 > logger_; + }; + + //! Structure of general variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage. + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike variables. + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! Total refractory steps. + int RefractoryCounts_; + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the surrogate gradient. + double + get_surrogate_gradient_() const + { + return S_.surrogate_gradient_; + } + + //! Get the current value of the learning signal. + double + get_learning_signal_() const + { + return S_.learning_signal_; + } + + // the order in which the structure instances are defined is important for speed + + //!< Structure of parameters. + Parameters_ P_; + + //!< Structure of state variables. + State_ S_; + + //!< Structure of general variables. + Variables_ V_; + + //!< Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_iaf_bsshslm_2020 > recordablesMap_; +}; + +inline size_t +eprop_iaf_bsshslm_2020::send_test_event( Node& target, size_t receptor_type, synindex, bool ) +{ + SpikeEvent e; + e.set_sender( *this ); + return target.handles_test_event( e, receptor_type ); +} + +inline size_t +eprop_iaf_bsshslm_2020::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_bsshslm_2020::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_bsshslm_2020::handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_iaf_bsshslm_2020::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_iaf_bsshslm_2020::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); +} + +inline void +eprop_iaf_bsshslm_2020::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_IAF_BSSHSLM_2020_H diff --git a/models/eprop_learning_signal_connection_bsshslm_2020.cpp b/models/eprop_learning_signal_connection_bsshslm_2020.cpp new file mode 100644 index 0000000000..fe29c2a84c --- /dev/null +++ b/models/eprop_learning_signal_connection_bsshslm_2020.cpp @@ -0,0 +1,32 @@ +/* + * eprop_learning_signal_connection_bsshslm_2020.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#include "eprop_learning_signal_connection_bsshslm_2020.h" + +// nestkernel +#include "nest_impl.h" + +void +nest::register_eprop_learning_signal_connection_bsshslm_2020( const std::string& name ) +{ + register_connection_model< eprop_learning_signal_connection_bsshslm_2020 >( name ); +} diff --git a/models/eprop_learning_signal_connection_bsshslm_2020.h b/models/eprop_learning_signal_connection_bsshslm_2020.h new file mode 100644 index 0000000000..98ad6687cf --- /dev/null +++ b/models/eprop_learning_signal_connection_bsshslm_2020.h @@ -0,0 +1,225 @@ +/* + * eprop_learning_signal_connection_bsshslm_2020.h + * + * This file is part of NEST. 
+ * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + + +#ifndef EPROP_LEARNING_SIGNAL_CONNECTION_BSSHSLM_2020_H +#define EPROP_LEARNING_SIGNAL_CONNECTION_BSSHSLM_2020_H + +// nestkernel +#include "connection.h" + +namespace nest +{ + +/* BeginUserDocs: synapse, e-prop plasticity + +Short description ++++++++++++++++++ + +Synapse model transmitting feedback learning signals for e-prop plasticity + +Description ++++++++++++ + +``eprop_learning_signal_connection_bsshslm_2020`` is an implementation of a feedback connector from +``eprop_readout_bsshslm_2020`` readout neurons to ``eprop_iaf_bsshslm_2020`` or ``eprop_iaf_adapt_bsshslm_2020`` +recurrent neurons that transmits the learning signals :math:`L_j^t` for eligibility propagation (e-prop) plasticity and +has a static weight :math:`B_{jk}`. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the +model name the paper that introduced it by the first letter of the authors' last +names and the publication year. + +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_bsshslm_2020<../models/eprop_iaf_bsshslm_2020/>` + * :doc:`eprop_iaf_adapt_bsshslm_2020<../models/eprop_iaf_adapt_bsshslm_2020/>` + * :doc:`eprop_readout_bsshslm_2020<../models/eprop_readout_bsshslm_2020/>` + * :doc:`eprop_synapse_bsshslm_2020<../models/eprop_synapse_bsshslm_2020/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +========= ===== ================ ======= =============== +**Individual synapse parameters** +-------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========= ===== ================ ======= =============== +delay ms :math:`d_{jk}` 1.0 Dendritic delay +weight pA :math:`B_{jk}` 1.0 Synaptic weight +========= ===== ================ ======= =============== + +Recordables ++++++++++++ + +The following variables can be recorded: + + - synaptic weight ``weight`` + +Usage ++++++ + +This model can only be used in combination with the other e-prop models, +whereby the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +Transmits ++++++++++ + +LearningSignalConnectionEvent + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y +.. 
[2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. Event-based implementation of + eligibility propagation (in preparation) + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. listexamples:: eprop_learning_signal_connection_bsshslm_2020 + +EndUserDocs */ + +void register_eprop_learning_signal_connection_bsshslm_2020( const std::string& name ); + +/** + * Class implementing a synapse model transmitting secondary feedback learning signals for e-prop plasticity + * according to Bellec et al. (2020). + */ +template < typename targetidentifierT > +class eprop_learning_signal_connection_bsshslm_2020 : public Connection< targetidentifierT > +{ + +public: + //! Type of the common synapse properties. + typedef CommonSynapseProperties CommonPropertiesType; + + //! Type of the connection base. + typedef Connection< targetidentifierT > ConnectionBase; + + //! Properties of the connection model. + static constexpr ConnectionModelProperties properties = ConnectionModelProperties::HAS_DELAY; + + //! Default constructor. + eprop_learning_signal_connection_bsshslm_2020() + : ConnectionBase() + , weight_( 1.0 ) + { + } + + //! Get the secondary learning signal event. + SecondaryEvent* get_secondary_event(); + + using ConnectionBase::get_delay_steps; + using ConnectionBase::get_rport; + using ConnectionBase::get_target; + + //! Check if the target accepts the event and receptor type requested by the sender. + void + check_connection( Node& s, Node& t, size_t receptor_type, const CommonPropertiesType& ) + { + LearningSignalConnectionEvent ge; + + s.sends_secondary_event( ge ); + ge.set_sender( s ); + Connection< targetidentifierT >::target_.set_rport( t.handles_test_event( ge, receptor_type ) ); + Connection< targetidentifierT >::target_.set_target( &t ); + } + + //! Send the learning signal event. + bool + send( Event& e, size_t t, const CommonSynapseProperties& ) + { + e.set_weight( weight_ ); + e.set_delay_steps( get_delay_steps() ); + e.set_receiver( *get_target( t ) ); + e.set_rport( get_rport() ); + e(); + return true; + } + + //! Get the model attributes and their values. + void get_status( DictionaryDatum& d ) const; + + //! Set the values of the model attributes. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + //! Set the synaptic weight to the provided value. + void + set_weight( const double w ) + { + weight_ = w; + } + +private: + //! Synaptic weight. 
+ double weight_; +}; + +template < typename targetidentifierT > +constexpr ConnectionModelProperties eprop_learning_signal_connection_bsshslm_2020< targetidentifierT >::properties; + +template < typename targetidentifierT > +void +eprop_learning_signal_connection_bsshslm_2020< targetidentifierT >::get_status( DictionaryDatum& d ) const +{ + ConnectionBase::get_status( d ); + def< double >( d, names::weight, weight_ ); + def< long >( d, names::size_of, sizeof( *this ) ); +} + +template < typename targetidentifierT > +void +eprop_learning_signal_connection_bsshslm_2020< targetidentifierT >::set_status( const DictionaryDatum& d, + ConnectorModel& cm ) +{ + ConnectionBase::set_status( d, cm ); + updateValue< double >( d, names::weight, weight_ ); +} + +template < typename targetidentifierT > +SecondaryEvent* +eprop_learning_signal_connection_bsshslm_2020< targetidentifierT >::get_secondary_event() +{ + return new LearningSignalConnectionEvent(); +} + +} // namespace nest + +#endif // EPROP_LEARNING_SIGNAL_CONNECTION_BSSHSLM_2020_H diff --git a/models/eprop_readout_bsshslm_2020.cpp b/models/eprop_readout_bsshslm_2020.cpp new file mode 100644 index 0000000000..76317bc643 --- /dev/null +++ b/models/eprop_readout_bsshslm_2020.cpp @@ -0,0 +1,433 @@ +/* + * eprop_readout_bsshslm_2020.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +// nest models +#include "eprop_readout_bsshslm_2020.h" + +// C++ +#include + +// libnestutil +#include "dict_util.h" +#include "numerics.h" + +// nestkernel +#include "exceptions.h" +#include "kernel_manager.h" +#include "nest_impl.h" +#include "universal_data_logger_impl.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +void +register_eprop_readout_bsshslm_2020( const std::string& name ) +{ + register_node_model< eprop_readout_bsshslm_2020 >( name ); +} + +/* ---------------------------------------------------------------- + * Recordables map + * ---------------------------------------------------------------- */ + +RecordablesMap< eprop_readout_bsshslm_2020 > eprop_readout_bsshslm_2020::recordablesMap_; + +template <> +void +RecordablesMap< eprop_readout_bsshslm_2020 >::create() +{ + insert_( names::error_signal, &eprop_readout_bsshslm_2020::get_error_signal_ ); + insert_( names::readout_signal, &eprop_readout_bsshslm_2020::get_readout_signal_ ); + insert_( names::readout_signal_unnorm, &eprop_readout_bsshslm_2020::get_readout_signal_unnorm_ ); + insert_( names::target_signal, &eprop_readout_bsshslm_2020::get_target_signal_ ); + insert_( names::V_m, &eprop_readout_bsshslm_2020::get_v_m_ ); +} + +/* ---------------------------------------------------------------- + * Default constructors for parameters, state, and buffers + * ---------------------------------------------------------------- */ + +eprop_readout_bsshslm_2020::Parameters_::Parameters_() + : C_m_( 250.0 ) + , E_L_( 0.0 ) + , I_e_( 0.0 ) + , loss_( "mean_squared_error" ) + , regular_spike_arrival_( true ) + , tau_m_( 10.0 ) + , V_min_( -std::numeric_limits< double >::max() ) +{ +} + +eprop_readout_bsshslm_2020::State_::State_() + : error_signal_( 0.0 ) + , readout_signal_( 0.0 ) + , readout_signal_unnorm_( 0.0 ) + , target_signal_( 0.0 ) + , i_in_( 0.0 ) + , v_m_( 0.0 ) + , z_in_( 0.0 ) +{ +} + +eprop_readout_bsshslm_2020::Buffers_::Buffers_( eprop_readout_bsshslm_2020& n ) + : logger_( n ) +{ +} + +eprop_readout_bsshslm_2020::Buffers_::Buffers_( const Buffers_&, eprop_readout_bsshslm_2020& n ) + : logger_( n ) +{ +} + +/* ---------------------------------------------------------------- + * Getter and setter functions for parameters and state + * ---------------------------------------------------------------- */ + +void +eprop_readout_bsshslm_2020::Parameters_::get( DictionaryDatum& d ) const +{ + def< double >( d, names::C_m, C_m_ ); + def< double >( d, names::E_L, E_L_ ); + def< double >( d, names::I_e, I_e_ ); + def< std::string >( d, names::loss, loss_ ); + def< bool >( d, names::regular_spike_arrival, regular_spike_arrival_ ); + def< double >( d, names::tau_m, tau_m_ ); + def< double >( d, names::V_min, V_min_ + E_L_ ); +} + +double +eprop_readout_bsshslm_2020::Parameters_::set( const DictionaryDatum& d, Node* node ) +{ + // if leak potential is changed, adjust all variables defined relative to it + const double ELold = E_L_; + updateValueParam< double >( d, names::E_L, E_L_, node ); + const double delta_EL = E_L_ - ELold; + + V_min_ -= updateValueParam< double >( d, names::V_min, V_min_, node ) ? 
E_L_ : delta_EL; + + updateValueParam< double >( d, names::C_m, C_m_, node ); + updateValueParam< double >( d, names::I_e, I_e_, node ); + updateValueParam< std::string >( d, names::loss, loss_, node ); + updateValueParam< bool >( d, names::regular_spike_arrival, regular_spike_arrival_, node ); + updateValueParam< double >( d, names::tau_m, tau_m_, node ); + + if ( C_m_ <= 0 ) + { + throw BadProperty( "Membrane capacitance C_m > 0 required." ); + } + + if ( loss_ != "mean_squared_error" and loss_ != "cross_entropy" ) + { + throw BadProperty( "Loss function loss from [\"mean_squared_error\", \"cross_entropy\"] required." ); + } + + if ( tau_m_ <= 0 ) + { + throw BadProperty( "Membrane time constant tau_m > 0 required." ); + } + + return delta_EL; +} + +void +eprop_readout_bsshslm_2020::State_::get( DictionaryDatum& d, const Parameters_& p ) const +{ + def< double >( d, names::V_m, v_m_ + p.E_L_ ); + def< double >( d, names::error_signal, error_signal_ ); + def< double >( d, names::readout_signal, readout_signal_ ); + def< double >( d, names::readout_signal_unnorm, readout_signal_unnorm_ ); + def< double >( d, names::target_signal, target_signal_ ); +} + +void +eprop_readout_bsshslm_2020::State_::set( const DictionaryDatum& d, const Parameters_& p, double delta_EL, Node* node ) +{ + v_m_ -= updateValueParam< double >( d, names::V_m, v_m_, node ) ? p.E_L_ : delta_EL; +} + +/* ---------------------------------------------------------------- + * Default and copy constructor for node + * ---------------------------------------------------------------- */ + +eprop_readout_bsshslm_2020::eprop_readout_bsshslm_2020() + : EpropArchivingNodeReadout() + , P_() + , S_() + , B_( *this ) +{ + recordablesMap_.create(); +} + +eprop_readout_bsshslm_2020::eprop_readout_bsshslm_2020( const eprop_readout_bsshslm_2020& n ) + : EpropArchivingNodeReadout( n ) + , P_( n.P_ ) + , S_( n.S_ ) + , B_( n.B_, *this ) +{ +} + +/* ---------------------------------------------------------------- + * Node initialization functions + * ---------------------------------------------------------------- */ + +void +eprop_readout_bsshslm_2020::init_buffers_() +{ + B_.normalization_rate_ = 0; + B_.spikes_.clear(); // includes resize + B_.currents_.clear(); // includes resize + B_.logger_.reset(); // includes resize +} + +void +eprop_readout_bsshslm_2020::pre_run_hook() +{ + B_.logger_.init(); // ensures initialization in case multimeter connected after Simulate + + if ( P_.loss_ == "mean_squared_error" ) + { + compute_error_signal = &eprop_readout_bsshslm_2020::compute_error_signal_mean_squared_error; + V_.signal_to_other_readouts_ = false; + } + else if ( P_.loss_ == "cross_entropy" ) + { + compute_error_signal = &eprop_readout_bsshslm_2020::compute_error_signal_cross_entropy; + V_.signal_to_other_readouts_ = true; + } + + const double dt = Time::get_resolution().get_ms(); + + V_.P_v_m_ = std::exp( -dt / P_.tau_m_ ); // called kappa in reference [1] + V_.P_i_in_ = P_.tau_m_ / P_.C_m_ * ( 1.0 - V_.P_v_m_ ); + V_.P_z_in_ = P_.regular_spike_arrival_ ? 
1.0 : 1.0 - V_.P_v_m_; +} + +long +eprop_readout_bsshslm_2020::get_shift() const +{ + return offset_gen_ + delay_in_rec_ + delay_rec_out_; +} + +bool +eprop_readout_bsshslm_2020::is_eprop_recurrent_node() const +{ + return false; +} + +/* ---------------------------------------------------------------- + * Update function + * ---------------------------------------------------------------- */ + +void +eprop_readout_bsshslm_2020::update( Time const& origin, const long from, const long to ) +{ + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + const bool with_reset = kernel().simulation_manager.get_eprop_reset_neurons_on_update(); + const long shift = get_shift(); + + const size_t buffer_size = kernel().connection_manager.get_min_delay(); + + std::vector< double > error_signal_buffer( buffer_size, 0.0 ); + std::vector< double > readout_signal_unnorm_buffer( buffer_size, 0.0 ); + + for ( long lag = from; lag < to; ++lag ) + { + const long t = origin.get_steps() + lag; + const long interval_step = ( t - shift ) % update_interval; + const long interval_step_signals = ( t - shift - delay_out_norm_ ) % update_interval; + + + if ( interval_step == 0 ) + { + erase_used_update_history(); + erase_used_eprop_history(); + + if ( with_reset ) + { + S_.v_m_ = 0.0; + } + } + + S_.z_in_ = B_.spikes_.get_value( lag ); + + S_.v_m_ = V_.P_i_in_ * S_.i_in_ + V_.P_z_in_ * S_.z_in_ + V_.P_v_m_ * S_.v_m_; + S_.v_m_ = std::max( S_.v_m_, P_.V_min_ ); + + ( this->*compute_error_signal )( lag ); + + if ( interval_step_signals < update_interval - learning_window ) + { + S_.target_signal_ = 0.0; + S_.readout_signal_ = 0.0; + S_.error_signal_ = 0.0; + } + + B_.normalization_rate_ = 0.0; + + if ( V_.signal_to_other_readouts_ ) + { + readout_signal_unnorm_buffer[ lag ] = S_.readout_signal_unnorm_; + } + + error_signal_buffer[ lag ] = S_.error_signal_; + + write_error_signal_to_history( t, S_.error_signal_ ); + + S_.i_in_ = B_.currents_.get_value( lag ) + P_.I_e_; + + B_.logger_.record_data( t ); + } + + LearningSignalConnectionEvent error_signal_event; + error_signal_event.set_coeffarray( error_signal_buffer ); + kernel().event_delivery_manager.send_secondary( *this, error_signal_event ); + + if ( V_.signal_to_other_readouts_ ) + { + // time is one time step longer than the final interval_step to enable sending the + // unnormalized readout signal one time step in advance so that it is available + // in the next times step for computing the normalized readout signal + DelayedRateConnectionEvent readout_signal_unnorm_event; + readout_signal_unnorm_event.set_coeffarray( readout_signal_unnorm_buffer ); + kernel().event_delivery_manager.send_secondary( *this, readout_signal_unnorm_event ); + } + return; +} + +/* ---------------------------------------------------------------- + * Error signal functions + * ---------------------------------------------------------------- */ + +void +eprop_readout_bsshslm_2020::compute_error_signal_mean_squared_error( const long lag ) +{ + S_.readout_signal_ = S_.readout_signal_unnorm_; + S_.readout_signal_unnorm_ = S_.v_m_ + P_.E_L_; + S_.error_signal_ = S_.readout_signal_ - S_.target_signal_; +} + +void +eprop_readout_bsshslm_2020::compute_error_signal_cross_entropy( const long lag ) +{ + const double norm_rate = B_.normalization_rate_ + S_.readout_signal_unnorm_; + S_.readout_signal_ = S_.readout_signal_unnorm_ / norm_rate; + 
S_.readout_signal_unnorm_ = std::exp( S_.v_m_ + P_.E_L_ ); + S_.error_signal_ = S_.readout_signal_ - S_.target_signal_; +} + +/* ---------------------------------------------------------------- + * Event handling functions + * ---------------------------------------------------------------- */ + +void +eprop_readout_bsshslm_2020::handle( DelayedRateConnectionEvent& e ) +{ + const size_t rport = e.get_rport(); + assert( rport < SUP_RATE_RECEPTOR ); + + auto it = e.begin(); + assert( it != e.end() ); + + const double signal = e.get_weight() * e.get_coeffvalue( it ); + if ( rport == READOUT_SIG ) + { + B_.normalization_rate_ += signal; + } + else if ( rport == TARGET_SIG ) + { + S_.target_signal_ = signal; + } + + assert( it == e.end() ); +} + +void +eprop_readout_bsshslm_2020::handle( SpikeEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.spikes_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_multiplicity() ); +} + +void +eprop_readout_bsshslm_2020::handle( CurrentEvent& e ) +{ + assert( e.get_delay_steps() > 0 ); + + B_.currents_.add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), e.get_weight() * e.get_current() ); +} + +void +eprop_readout_bsshslm_2020::handle( DataLoggingRequest& e ) +{ + B_.logger_.handle( e ); +} + +double +eprop_readout_bsshslm_2020::compute_gradient( std::vector< long >& presyn_isis, + const long, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) +{ + auto eprop_hist_it = get_eprop_history( t_previous_trigger_spike ); + + double grad = 0.0; // gradient value to be calculated + double L = 0.0; // error signal + double z = 0.0; // spiking variable + double z_bar = 0.0; // low-pass filtered spiking variable + + for ( long presyn_isi : presyn_isis ) + { + z = 1.0; // set spiking variable to 1 for each incoming spike + + for ( long t = 0; t < presyn_isi; ++t ) + { + assert( eprop_hist_it != eprop_history_.end() ); + + L = eprop_hist_it->error_signal_; + + z_bar = V_.P_v_m_ * z_bar + V_.P_z_in_ * z; + grad += L * z_bar; + z = 0.0; // set spiking variable to 0 between spikes + + ++eprop_hist_it; + } + } + presyn_isis.clear(); + + const long learning_window = kernel().simulation_manager.get_eprop_learning_window().get_steps(); + if ( average_gradient ) + { + grad /= learning_window; + } + + return grad; +} + +} // namespace nest diff --git a/models/eprop_readout_bsshslm_2020.h b/models/eprop_readout_bsshslm_2020.h new file mode 100644 index 0000000000..ba25d07d36 --- /dev/null +++ b/models/eprop_readout_bsshslm_2020.h @@ -0,0 +1,525 @@ +/* + * eprop_readout_bsshslm_2020.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ *
+ */
+
+#ifndef EPROP_READOUT_BSSHSLM_2020_H
+#define EPROP_READOUT_BSSHSLM_2020_H
+
+// nestkernel
+#include "connection.h"
+#include "eprop_archiving_node.h"
+#include "eprop_archiving_node_impl.h"
+#include "event.h"
+#include "nest_types.h"
+#include "ring_buffer.h"
+#include "universal_data_logger.h"
+
+namespace nest
+{
+
+/* BeginUserDocs: neuron, e-prop plasticity, current-based
+
+Short description
++++++++++++++++++
+
+Current-based leaky integrate readout neuron model with delta-shaped
+postsynaptic currents for e-prop plasticity
+
+Description
++++++++++++
+
+``eprop_readout_bsshslm_2020`` is an implementation of an integrate-and-fire neuron model
+with delta-shaped postsynaptic currents used as a readout neuron for eligibility propagation (e-prop) plasticity.
+
+E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_.
+
+The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the
+model name the paper that introduced it by the first letter of the authors' last
+names and the publication year.
+
+
+The membrane voltage time course :math:`v_j^t` of the neuron :math:`j` is given by:
+
+.. math::
+  v_j^t &= \kappa v_j^{t-1}+\sum_{i \neq j}W_{ji}^\mathrm{out}z_i^{t-1}
+  -z_j^{t-1}v_\mathrm{th} \,, \\
+  \kappa &= e^{-\frac{\Delta t}{\tau_\mathrm{m}}} \,,
+
+whereby :math:`W_{ji}^\mathrm{out}` are the output synaptic weights and
+:math:`z_i^{t-1}` are the recurrent presynaptic spike state variables.
+
+Descriptions of further parameters and variables can be found in the table below.
+
+An additional state variable and the corresponding differential
+equation represents a piecewise constant external current.
+
+See the documentation on the ``iaf_psc_delta`` neuron model for more information
+on the integration of the subthreshold dynamics.
+
+The change of the synaptic weight is calculated from the gradient
+:math:`g=\frac{\mathrm{d}{E}}{\mathrm{d}{W_{ji}}}` of the loss :math:`E` with
+respect to the synaptic weight :math:`W_{ji}`, which depends on the presynaptic
+spikes :math:`z_i^{t-1}` and the learning signal :math:`L_j^t` emitted by the readout
+neurons.
+
+.. math::
+  \frac{\mathrm{d}E}{\mathrm{d}W_{ji}} = g = \sum_t L_j^t \bar{z}_i^{t-1}\,.
+
+The presynaptic spike trains are low-pass filtered with an exponential kernel:
+
+.. math::
+  \bar{z}_i^t &=\mathcal{F}_\kappa(z_i^t)\,, \\
+  \mathcal{F}_\kappa(z_i^t) &= \kappa\, \mathcal{F}_\kappa(z_i^{t-1}) + z_i^t
+  \;\text{with}\, \mathcal{F}_\kappa(z_i^0)=z_i^0\,.
+
+Since readout neurons are leaky integrators without a spiking mechanism, the
+formula for computing the gradient lacks the surrogate gradient /
+pseudo-derivative and a firing regularization term.
+
+For more information on e-prop plasticity, see the documentation on the other e-prop models:
+
+ * :doc:`eprop_iaf_bsshslm_2020<../models/eprop_iaf_bsshslm_2020/>`
+ * :doc:`eprop_iaf_adapt_bsshslm_2020<../models/eprop_iaf_adapt_bsshslm_2020/>`
+ * :doc:`eprop_synapse_bsshslm_2020<../models/eprop_synapse_bsshslm_2020/>`
+ * :doc:`eprop_learning_signal_connection_bsshslm_2020<../models/eprop_learning_signal_connection_bsshslm_2020/>`
+
+Details on the event-based NEST implementation of e-prop can be found in [2]_.
+
+Parameters
+++++++++++
+
+The following parameters can be set in the status dictionary.
+ +===================== ======= ===================== ================== =============================================== +**Neuron parameters** +---------------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +===================== ======= ===================== ================== =============================================== +C_m pF :math:`C_\text{m}` 250.0 Capacitance of the membrane +E_L mV :math:`E_\text{L}` 0.0 Leak / resting membrane potential +I_e pA :math:`I_\text{e}` 0.0 Constant external input current +loss :math:`E` mean_squared_error Loss function + ["mean_squared_error", "cross_entropy"] +regular_spike_arrival Boolean True If True, the input spikes arrive at the + end of the time step, if False at the + beginning (determines PSC scale) +tau_m ms :math:`\tau_\text{m}` 10.0 Time constant of the membrane +V_min mV :math:`v_\text{min}` -1.79e+308 Absolute lower bound of the membrane voltage +===================== ======= ===================== ================== =============================================== + +The following state variables evolve during simulation. + +===================== ==== =============== ============= ========================== +**Neuron state variables and recordables** +----------------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +===================== ==== =============== ============= ========================== +error_signal mV :math:`L_j` 0.0 Error signal +readout_signal mV :math:`y_j` 0.0 Readout signal +readout_signal_unnorm mV 0.0 Unnormalized readout signal +target_signal mV :math:`y^*_j` 0.0 Target signal +V_m mV :math:`v_j` 0.0 Membrane voltage +===================== ==== =============== ============= ========================== + +Recordables ++++++++++++ + +The following variables can be recorded: + + - error signal ``error_signal`` + - readout signal ``readout_signal`` + - readout signal ``readout_signal_unnorm`` + - target signal ``target_signal`` + - membrane potential ``V_m`` + +Usage ++++++ + +This model can only be used in combination with the other e-prop models, +whereby the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. Event-based implementation of + eligibility propagation (in preparation) + +Sends +++++++++ + +LearningSignalConnectionEvent, DelayedRateConnectionEvent + +Receives +++++++++ + +SpikeEvent, CurrentEvent, DelayedRateConnectionEvent, DataLoggingRequest + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. 
listexamples:: eprop_readout_bsshslm_2020 + +EndUserDocs */ + +void register_eprop_readout_bsshslm_2020( const std::string& name ); + +/** + * Class implementing a current-based leaky integrate readout neuron model with delta-shaped postsynaptic currents for + * e-prop plasticity according to Bellec et al. (2020). + */ +class eprop_readout_bsshslm_2020 : public EpropArchivingNodeReadout +{ + +public: + //! Default constructor. + eprop_readout_bsshslm_2020(); + + //! Copy constructor. + eprop_readout_bsshslm_2020( const eprop_readout_bsshslm_2020& ); + + using Node::handle; + using Node::handles_test_event; + + using Node::sends_secondary_event; + + void + sends_secondary_event( LearningSignalConnectionEvent& ) override + { + } + + void + sends_secondary_event( DelayedRateConnectionEvent& ) override + { + } + + void handle( SpikeEvent& ) override; + void handle( CurrentEvent& ) override; + void handle( DelayedRateConnectionEvent& ) override; + void handle( DataLoggingRequest& ) override; + + size_t handles_test_event( SpikeEvent&, size_t ) override; + size_t handles_test_event( CurrentEvent&, size_t ) override; + size_t handles_test_event( DelayedRateConnectionEvent&, size_t ) override; + size_t handles_test_event( DataLoggingRequest&, size_t ) override; + + void get_status( DictionaryDatum& ) const override; + void set_status( const DictionaryDatum& ) override; + + double compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ) override; + + void pre_run_hook() override; + long get_shift() const override; + bool is_eprop_recurrent_node() const override; + void update( Time const&, const long, const long ) override; + +protected: + void init_buffers_() override; + +private: + //! Compute the error signal based on the mean-squared error loss. + void compute_error_signal_mean_squared_error( const long lag ); + + //! Compute the error signal based on the cross-entropy loss. + void compute_error_signal_cross_entropy( const long lag ); + + //! Compute the error signal based on a loss function. + void ( eprop_readout_bsshslm_2020::*compute_error_signal )( const long lag ); + + //! Map for storing a static set of recordables. + friend class RecordablesMap< eprop_readout_bsshslm_2020 >; + + //! Logger for universal data supporting the data logging request / reply mechanism. Populated with a recordables map. + friend class UniversalDataLogger< eprop_readout_bsshslm_2020 >; + + //! Structure of parameters. + struct Parameters_ + { + //! Capacitance of the membrane (pF). + double C_m_; + + //! Leak / resting membrane potential (mV). + double E_L_; + + //! Constant external input current (pA). + double I_e_; + + //! Loss function ["mean_squared_error", "cross_entropy"]. + std::string loss_; + + //! If True, the input spikes arrive at the beginning of the time step, if False at the end (determines PSC scale). + bool regular_spike_arrival_; + + //! Time constant of the membrane (ms). + double tau_m_; + + //! Absolute lower bound of the membrane voltage relative to the leak membrane potential (mV). + double V_min_; + + //! Default constructor. + Parameters_(); + + //! Get the parameters and their values. + void get( DictionaryDatum& ) const; + + //! Set the parameters and throw errors in case of invalid values. + double set( const DictionaryDatum&, Node* ); + }; + + //! Structure of state variables. + struct State_ + { + //! Error signal. 
Deviation between the readout and the target signal. + double error_signal_; + + //! Readout signal. Leaky integrated spikes emitted by the recurrent network. + double readout_signal_; + + //! Unnormalized readout signal. Readout signal not yet divided by the readout signals of other readout neurons. + double readout_signal_unnorm_; + + //! Target / teacher signal that the network is supposed to learn. + double target_signal_; + + //! Input current (pA). + double i_in_; + + //! Membrane voltage relative to the leak membrane potential (mV). + double v_m_; + + //! Binary input spike variables - 1.0 if the neuron has spiked in the previous time step and 0.0 otherwise. + double z_in_; + + //! Default constructor. + State_(); + + //! Get the state variables and their values. + void get( DictionaryDatum&, const Parameters_& ) const; + + //! Set the state variables. + void set( const DictionaryDatum&, const Parameters_&, double, Node* ); + }; + + //! Structure of buffers. + struct Buffers_ + { + //! Default constructor. + Buffers_( eprop_readout_bsshslm_2020& ); + + //! Copy constructor. + Buffers_( const Buffers_&, eprop_readout_bsshslm_2020& ); + + //! Normalization rate of the readout signal. Sum of the readout signals of all readout neurons. + double normalization_rate_; + + //! Buffer for incoming spikes. + RingBuffer spikes_; + + //! Buffer for incoming currents. + RingBuffer currents_; + + //! Logger for universal data. + UniversalDataLogger< eprop_readout_bsshslm_2020 > logger_; + }; + + //! Structure of general variables. + struct Variables_ + { + //! Propagator matrix entry for evolving the membrane voltage. + double P_v_m_; + + //! Propagator matrix entry for evolving the incoming spike variables. + double P_z_in_; + + //! Propagator matrix entry for evolving the incoming currents. + double P_i_in_; + + //! If the loss requires communication between the readout neurons and thus a buffer for the exchanged signals. + bool signal_to_other_readouts_; + }; + + //! Minimal spike receptor type. Start with 1 to forbid port 0 and avoid accidental creation of connections with no + //! receptor type set. + static const size_t MIN_RATE_RECEPTOR = 1; + + //! Enumeration of spike receptor types. + enum RateSynapseTypes + { + READOUT_SIG = MIN_RATE_RECEPTOR, + TARGET_SIG, + SUP_RATE_RECEPTOR + }; + + //! Get the current value of the membrane voltage. + double + get_v_m_() const + { + return S_.v_m_ + P_.E_L_; + } + + //! Get the current value of the normalized readout signal. + double + get_readout_signal_() const + { + return S_.readout_signal_; + } + + //! Get the current value of the unnormalized readout signal. + double + get_readout_signal_unnorm_() const + { + return S_.readout_signal_unnorm_; + } + + //! Get the current value of the target signal. + double + get_target_signal_() const + { + return S_.target_signal_; + } + + //! Get the current value of the error signal. + double + get_error_signal_() const + { + return S_.error_signal_; + } + + // the order in which the structure instances are defined is important for speed + + //!< Structure of parameters. + Parameters_ P_; + + //!< Structure of state variables. + State_ S_; + + //!< Structure of general variables. + Variables_ V_; + + //!< Structure of buffers. + Buffers_ B_; + + //! Map storing a static set of recordables. 
+ static RecordablesMap< eprop_readout_bsshslm_2020 > recordablesMap_; +}; + +inline size_t +eprop_readout_bsshslm_2020::handles_test_event( SpikeEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_readout_bsshslm_2020::handles_test_event( CurrentEvent&, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return 0; +} + +inline size_t +eprop_readout_bsshslm_2020::handles_test_event( DelayedRateConnectionEvent& e, size_t receptor_type ) +{ + size_t step_rate_model_id = kernel().model_manager.get_node_model_id( "step_rate_generator" ); + size_t model_id = e.get_sender().get_model_id(); + + if ( step_rate_model_id == model_id and receptor_type != TARGET_SIG ) + { + throw IllegalConnection( + "eprop_readout_bsshslm_2020 neurons expect a connection with a step_rate_generator node through receptor_type " + "2." ); + } + + if ( receptor_type < MIN_RATE_RECEPTOR or receptor_type >= SUP_RATE_RECEPTOR ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return receptor_type; +} + +inline size_t +eprop_readout_bsshslm_2020::handles_test_event( DataLoggingRequest& dlr, size_t receptor_type ) +{ + if ( receptor_type != 0 ) + { + throw UnknownReceptorType( receptor_type, get_name() ); + } + + return B_.logger_.connect_logging_device( dlr, recordablesMap_ ); +} + +inline void +eprop_readout_bsshslm_2020::get_status( DictionaryDatum& d ) const +{ + P_.get( d ); + S_.get( d, P_ ); + ( *d )[ names::recordables ] = recordablesMap_.get_list(); + + DictionaryDatum receptor_dict_ = new Dictionary(); + ( *receptor_dict_ )[ names::readout_signal ] = READOUT_SIG; + ( *receptor_dict_ )[ names::target_signal ] = TARGET_SIG; + + ( *d )[ names::receptor_types ] = receptor_dict_; +} + +inline void +eprop_readout_bsshslm_2020::set_status( const DictionaryDatum& d ) +{ + // temporary copies in case of errors + Parameters_ ptmp = P_; + State_ stmp = S_; + + // make sure that ptmp and stmp consistent - throw BadProperty if not + const double delta_EL = ptmp.set( d, this ); + stmp.set( d, ptmp, delta_EL, this ); + + P_ = ptmp; + S_ = stmp; +} + +} // namespace nest + +#endif // EPROP_READOUT_BSSHSLM_2020_H diff --git a/models/eprop_synapse_bsshslm_2020.cpp b/models/eprop_synapse_bsshslm_2020.cpp new file mode 100644 index 0000000000..ceb1dba4d1 --- /dev/null +++ b/models/eprop_synapse_bsshslm_2020.cpp @@ -0,0 +1,150 @@ +/* + * eprop_synapse_bsshslm_2020.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . 
+ * + */ + +#include "eprop_synapse_bsshslm_2020.h" + +// nestkernel +#include "nest_impl.h" + +namespace nest +{ + +void +register_eprop_synapse_bsshslm_2020( const std::string& name ) +{ + register_connection_model< eprop_synapse_bsshslm_2020 >( name ); +} + +EpropSynapseBSSHSLM2020CommonProperties::EpropSynapseBSSHSLM2020CommonProperties() + : CommonSynapseProperties() + , average_gradient_( false ) + , optimizer_cp_( new WeightOptimizerCommonPropertiesGradientDescent() ) +{ +} + +EpropSynapseBSSHSLM2020CommonProperties::EpropSynapseBSSHSLM2020CommonProperties( + const EpropSynapseBSSHSLM2020CommonProperties& cp ) + : CommonSynapseProperties( cp ) + , average_gradient_( cp.average_gradient_ ) + , optimizer_cp_( cp.optimizer_cp_->clone() ) +{ +} + +EpropSynapseBSSHSLM2020CommonProperties::~EpropSynapseBSSHSLM2020CommonProperties() +{ + delete optimizer_cp_; +} + +void +EpropSynapseBSSHSLM2020CommonProperties::get_status( DictionaryDatum& d ) const +{ + CommonSynapseProperties::get_status( d ); + def< bool >( d, names::average_gradient, average_gradient_ ); + def< std::string >( d, names::optimizer, optimizer_cp_->get_name() ); + DictionaryDatum optimizer_dict = new Dictionary; + optimizer_cp_->get_status( optimizer_dict ); + ( *d )[ names::optimizer ] = optimizer_dict; +} + +void +EpropSynapseBSSHSLM2020CommonProperties::set_status( const DictionaryDatum& d, ConnectorModel& cm ) +{ + CommonSynapseProperties::set_status( d, cm ); + updateValue< bool >( d, names::average_gradient, average_gradient_ ); + + if ( d->known( names::optimizer ) ) + { + DictionaryDatum optimizer_dict = getValue< DictionaryDatum >( d->lookup( names::optimizer ) ); + + std::string new_optimizer; + const bool set_optimizer = updateValue< std::string >( optimizer_dict, names::type, new_optimizer ); + if ( set_optimizer and new_optimizer != optimizer_cp_->get_name() ) + { + if ( kernel().connection_manager.get_num_connections( cm.get_syn_id() ) > 0 ) + { + throw BadParameter( "The optimizer cannot be changed because synapses have been created." ); + } + + // TODO: selection here should be based on an optimizer registry and a factory + // delete is in if/else if because we must delete only when we are sure that we have a valid optimizer + if ( new_optimizer == "gradient_descent" ) + { + delete optimizer_cp_; + optimizer_cp_ = new WeightOptimizerCommonPropertiesGradientDescent(); + } + else if ( new_optimizer == "adam" ) + { + delete optimizer_cp_; + optimizer_cp_ = new WeightOptimizerCommonPropertiesAdam(); + } + else + { + throw BadProperty( "optimizer from [\"gradient_descent\", \"adam\"] required." 
); + } + } + + // we can now set the defaults on the new optimizer common properties + optimizer_cp_->set_status( optimizer_dict ); + } +} + +template <> +void +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierPtrRport > >::disable_connection( const size_t lcid ) +{ + assert( not C_[ lcid ].is_disabled() ); + C_[ lcid ].disable(); + C_[ lcid ].delete_optimizer(); +} + +template <> +void +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierIndex > >::disable_connection( const size_t lcid ) +{ + assert( not C_[ lcid ].is_disabled() ); + C_[ lcid ].disable(); + C_[ lcid ].delete_optimizer(); +} + + +template <> +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierPtrRport > >::~Connector() +{ + for ( auto& c : C_ ) + { + c.delete_optimizer(); + } + C_.clear(); +} + +template <> +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierIndex > >::~Connector() +{ + for ( auto& c : C_ ) + { + c.delete_optimizer(); + } + C_.clear(); +} + + +} // namespace nest diff --git a/models/eprop_synapse_bsshslm_2020.h b/models/eprop_synapse_bsshslm_2020.h new file mode 100644 index 0000000000..52223ca67b --- /dev/null +++ b/models/eprop_synapse_bsshslm_2020.h @@ -0,0 +1,618 @@ +/* + * eprop_synapse_bsshslm_2020.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_SYNAPSE_BSSHSLM_2020_H +#define EPROP_SYNAPSE_BSSHSLM_2020_H + +// nestkernel +#include "connection.h" +#include "connector_base.h" +#include "eprop_archiving_node.h" +#include "target_identifier.h" +#include "weight_optimizer.h" + +namespace nest +{ + +/* BeginUserDocs: synapse, e-prop plasticity + +Short description ++++++++++++++++++ + +Synapse type for e-prop plasticity + +Description ++++++++++++ + +``eprop_synapse_bsshslm_2020`` is an implementation of a connector model to create synapses between postsynaptic +neurons :math:`j` and presynaptic neurons :math:`i` for eligibility propagation (e-prop) plasticity. + +E-prop plasticity was originally introduced and implemented in TensorFlow in [1]_. + +The suffix ``_bsshslm_2020`` follows the NEST convention to indicate in the +model name the paper that introduced it by the first letter of the authors' last +names and the publication year. + +The e-prop synapse collects the presynaptic spikes needed for calculating the +weight update. When it is time to update, it triggers the calculation of the +gradient which is specific to the post-synaptic neuron and is thus defined there. + +Eventually, it optimizes the weight with the specified optimizer. + +E-prop synapses require archiving of continuous quantities. Therefore e-prop +synapses can only be connected to neuron models that are capable of +archiving. So far, compatible models are ``eprop_iaf_bsshslm_2020``, +``eprop_iaf_adapt_bsshslm_2020``, and ``eprop_readout_bsshslm_2020``. 
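
As a rough illustration of this division of labor, here is a self-contained toy sketch. All type names and member functions below are invented stand-ins and are deliberately much simpler than the actual NEST kernel interfaces; in particular, the placeholder gradient stands in for the eligibility-trace and learning-signal sums described above.

.. code-block:: cpp

    #include <iostream>
    #include <vector>

    // Toy stand-ins for the real NEST classes; names and signatures are invented
    // for illustration and do not correspond to the kernel interfaces one-to-one.
    struct ToyEpropNeuron
    {
      // The gradient needs neuron-side history (surrogate gradients, learning
      // signals), so the synapse only hands over the collected inter-spike
      // intervals and lets the target neuron do the actual computation.
      double
      compute_gradient( std::vector< long >& presyn_isis ) const
      {
        double grad = 0.0;
        for ( const long isi : presyn_isis )
        {
          grad += 0.1 * static_cast< double >( isi ); // placeholder for the e-prop sums
        }
        presyn_isis.clear();
        return grad;
      }
    };

    struct ToyGradientDescentOptimizer
    {
      double eta = 0.01; // learning rate

      double
      optimized_weight( const double weight, const double gradient ) const
      {
        return weight - eta * gradient;
      }
    };

    struct ToyEpropSynapse
    {
      double weight = 1.0;
      std::vector< long > presyn_isis; // presynaptic spikes collected since the last update

      void
      update_weight( const ToyEpropNeuron& target, const ToyGradientDescentOptimizer& optimizer )
      {
        const double gradient = target.compute_gradient( presyn_isis );
        weight = optimizer.optimized_weight( weight, gradient );
      }
    };

    int
    main()
    {
      ToyEpropSynapse synapse;
      synapse.presyn_isis = { 3, 5, 2 };
      synapse.update_weight( ToyEpropNeuron(), ToyGradientDescentOptimizer() );
      std::cout << synapse.weight << "\n"; // 1.0 - 0.01 * 1.0 = 0.99
    }
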
+ +For more information on e-prop plasticity, see the documentation on the other e-prop models: + + * :doc:`eprop_iaf_bsshslm_2020<../models/eprop_iaf_bsshslm_2020/>` + * :doc:`eprop_iaf_adapt_bsshslm_2020<../models/eprop_iaf_adapt_bsshslm_2020/>` + * :doc:`eprop_readout_bsshslm_2020<../models/eprop_readout_bsshslm_2020/>` + * :doc:`eprop_learning_signal_connection_bsshslm_2020<../models/eprop_learning_signal_connection_bsshslm_2020/>` + +For more information on the optimizers, see the documentation of the weight optimizer: + + * :doc:`weight_optimizer<../models/weight_optimizer/>` + +Details on the event-based NEST implementation of e-prop can be found in [2]_. + +.. warning:: + + This synaptic plasticity rule does not take + :ref:`precise spike timing ` into + account. When calculating the weight update, the precise spike time part + of the timestamp is ignored. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +================ ======= =============== ======= ====================================================== +**Common synapse parameters** +------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +================ ======= =============== ======= ====================================================== +average_gradient Boolean False If True, average the gradient over the learning window +optimizer {} Dictionary of optimizer parameters +================ ======= =============== ======= ====================================================== + +============= ==== ========================= ======= ========================================================= +**Individual synapse parameters** +-------------------------------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +============= ==== ========================= ======= ========================================================= +delay ms :math:`d_{ji}` 1.0 Dendritic delay +tau_m_readout ms :math:`\tau_\text{m,out}` 10.0 Time constant for low-pass filtering of eligibility trace +weight pA :math:`W_{ji}` 1.0 Initial value of synaptic weight +============= ==== ========================= ======= ========================================================= + +Recordables ++++++++++++ + +The following variables can be recorded. + + - synaptic weight ``weight`` + +Usage ++++++ + +This model can only be used in combination with the other e-prop models, +whereby the network architecture requires specific wiring, input, and output. +The usage is demonstrated in several +:doc:`supervised regression and classification tasks <../auto_examples/eprop_plasticity/index>` +reproducing among others the original proof-of-concept tasks in [1]_. + +Transmits ++++++++++ + +SpikeEvent, DSSpikeEvent + +References +++++++++++ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. Event-based implementation of + eligibility propagation (in preparation) + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. 
listexamples:: eprop_synapse_bsshslm_2020 + +EndUserDocs */ + +/** + * Base class implementing common properties for the e-prop synapse model. + * + * This class in particular manages a pointer to weight-optimizer common properties to support + * exchanging the weight optimizer at runtime. Setting the weight-optimizer common properties + * determines the WO type. It can only be exchanged as long as no synapses for the model exist. + * The WO CP object is responsible for providing individual optimizer objects to synapses upon + * connection. + * + * @see WeightOptimizerCommonProperties + */ +class EpropSynapseBSSHSLM2020CommonProperties : public CommonSynapseProperties +{ +public: + // Default constructor. + EpropSynapseBSSHSLM2020CommonProperties(); + + //! Copy constructor. + EpropSynapseBSSHSLM2020CommonProperties( const EpropSynapseBSSHSLM2020CommonProperties& ); + + //! Assignment operator. + EpropSynapseBSSHSLM2020CommonProperties& operator=( const EpropSynapseBSSHSLM2020CommonProperties& ) = delete; + + //! Destructor. + ~EpropSynapseBSSHSLM2020CommonProperties(); + + //! Get parameter dictionary. + void get_status( DictionaryDatum& d ) const; + + //! Update values in parameter dictionary. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + //! If True, average the gradient over the learning window. + bool average_gradient_; + + /** + * Pointer to common properties object for weight optimizer. + * + * @note Must only be changed as long as no synapses of the model exist. + */ + WeightOptimizerCommonProperties* optimizer_cp_; +}; + +//! Register the eprop synapse model. +void register_eprop_synapse_bsshslm_2020( const std::string& name ); + +/** + * Class implementing a synapse model for e-prop plasticity according to Bellec et al. (2020). + * + * @note Several aspects of this synapse are in place to reproduce the Tensorflow implementation of Bellec et al (2020). + * + * @note Each synapse has a optimizer_ object managed through a `WeightOptimizer*`, pointing to an object of + * a specific weight optimizer type. This optimizer, drawing also on parameters in the `WeightOptimizerCommonProperties` + * accessible via the synapse models `CommonProperties::optimizer_cp_` pointer, computes the weight update for the + * neuron. The actual optimizer type can be selected at runtime (before creating any synapses) by exchanging the + * `optimizer_cp_` pointer. Individual optimizer objects are created by `check_connection()` when a synapse is actually + * created. It is important that the constructors of `eprop_synapse_bsshslm_2020` **do not** create optimizer objects + * and that the destructor **does not** delete optimizer objects; this currently leads to bugs when using Boosts's + * `spreadsort()` due to use of the copy constructor where it should suffice to use the move constructor. Therefore, + * `check_connection()`creates the optimizer object when it is needed and specializations of `Connector::~Connector()` + * and `Connector::disable_connection()` delete it by calling `delete_optimizer()`. A disadvantage of this approach is + * that the `default_connection` in the connector model does not have an optimizer object, whence it is not possible to + * set default (initial) values for the per-synapse optimizer. 
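To make the runtime selection of the optimizer described in the note above concrete, the following is a short, hypothetical PyNEST sketch; only the dictionary structure and parameter names (``type``, ``eta``, ``batch_size``) are taken from the documentation and ``set_status()`` code in this change, while the concrete values are placeholders::

    import nest

    nest.ResetKernel()

    # Select the optimizer type and its common parameters while no
    # eprop_synapse_bsshslm_2020 connections exist yet.
    nest.SetDefaults(
        "eprop_synapse_bsshslm_2020",
        {"optimizer": {"type": "adam", "eta": 1e-3, "batch_size": 1}},
    )

    # ... create neurons and call nest.Connect(...) as usual ...

    # Once synapses have been created, changing the optimizer type is rejected:
    # "The optimizer cannot be changed because synapses have been created."

    # Note: per-synapse optimizer state (e.g. Adam's m, v) cannot be given default
    # values here, since the default connection owns no optimizer object.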
+ * + * @note If we can find a way to modify our co-sorting of source and target tables in Boost's `spreadsort()` to only use + * move operations, it should be possible to create the individual optimizers in the copy constructor of + * `eprop_synapse_bsshslm_2020` and to delete it in the destructor. The `default_connection` can then own an optimizer + * and default values could be set on it. + */ +template < typename targetidentifierT > +class eprop_synapse_bsshslm_2020 : public Connection< targetidentifierT > +{ + +public: + //! Type of the common synapse properties. + typedef EpropSynapseBSSHSLM2020CommonProperties CommonPropertiesType; + + //! Type of the connection base. + typedef Connection< targetidentifierT > ConnectionBase; + + /** + * Properties of the connection model. + * + * @note Does not support LBL at present because we cannot properly cast GenericModel common props in that case. + */ + static constexpr ConnectionModelProperties properties = ConnectionModelProperties::HAS_DELAY + | ConnectionModelProperties::IS_PRIMARY | ConnectionModelProperties::REQUIRES_EPROP_ARCHIVING + | ConnectionModelProperties::SUPPORTS_HPC; + + //! Default constructor. + eprop_synapse_bsshslm_2020(); + + //! Destructor + ~eprop_synapse_bsshslm_2020(); + + //! Parameterized copy constructor. + eprop_synapse_bsshslm_2020( const eprop_synapse_bsshslm_2020& ); + + //! Assignment operator + eprop_synapse_bsshslm_2020& operator=( const eprop_synapse_bsshslm_2020& ); + + //! Move constructor + eprop_synapse_bsshslm_2020( eprop_synapse_bsshslm_2020&& ); + + //! Move assignment operator + eprop_synapse_bsshslm_2020& operator=( eprop_synapse_bsshslm_2020&& ); + + using ConnectionBase::get_delay; + using ConnectionBase::get_delay_steps; + using ConnectionBase::get_rport; + using ConnectionBase::get_target; + + //! Get parameter dictionary. + void get_status( DictionaryDatum& d ) const; + + //! Update values in parameter dictionary. + void set_status( const DictionaryDatum& d, ConnectorModel& cm ); + + //! Send the spike event. + bool send( Event& e, size_t thread, const EpropSynapseBSSHSLM2020CommonProperties& cp ); + + //! Dummy node for testing the connection. + class ConnTestDummyNode : public ConnTestDummyNodeBase + { + public: + using ConnTestDummyNodeBase::handles_test_event; + + size_t + handles_test_event( SpikeEvent&, size_t ) + { + return invalid_port; + } + + size_t + handles_test_event( DSSpikeEvent&, size_t ) + { + return invalid_port; + } + }; + + /** + * Check if the target accepts the event and receptor type requested by the sender. + * + * @note This sets the optimizer_ member. + */ + void check_connection( Node& s, Node& t, size_t receptor_type, const CommonPropertiesType& cp ); + + //! Set the synaptic weight to the provided value. + void + set_weight( const double w ) + { + weight_ = w; + } + + //! Delete optimizer + void delete_optimizer(); + +private: + //! Synaptic weight. + double weight_; + + //! The time step when the previous spike arrived. + long t_previous_spike_; + + //! The time step when the previous e-prop update was. + long t_previous_update_; + + //! The time step when the next e-prop update will be. + long t_next_update_; + + //! The time step when the spike arrived that triggered the previous e-prop update. + long t_previous_trigger_spike_; + + //! %Time constant for low-pass filtering the eligibility trace. + double tau_m_readout_; + + //! Low-pass filter of the eligibility trace. + double kappa_; + + //! If this connection is between two recurrent neurons. 
+ bool is_recurrent_to_recurrent_conn_; + + //! Vector of presynaptic inter-spike-intervals. + std::vector< long > presyn_isis_; + + /** + * Optimizer + * + * @note Pointer is set by check_connection() and deleted by delete_optimizer(). + */ + WeightOptimizer* optimizer_; +}; + +template < typename targetidentifierT > +constexpr ConnectionModelProperties eprop_synapse_bsshslm_2020< targetidentifierT >::properties; + +// Explicitly declare specializations of Connector methods that need to do special things for eprop_synapse_bsshslm_2020 +template <> +void Connector< eprop_synapse_bsshslm_2020< TargetIdentifierPtrRport > >::disable_connection( const size_t lcid ); + +template <> +void Connector< eprop_synapse_bsshslm_2020< TargetIdentifierIndex > >::disable_connection( const size_t lcid ); + +template <> +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierPtrRport > >::~Connector(); + +template <> +Connector< eprop_synapse_bsshslm_2020< TargetIdentifierIndex > >::~Connector(); + + +template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020() + : ConnectionBase() + , weight_( 1.0 ) + , t_previous_spike_( 0 ) + , t_previous_update_( 0 ) + , t_next_update_( 0 ) + , t_previous_trigger_spike_( 0 ) + , tau_m_readout_( 10.0 ) + , kappa_( std::exp( -Time::get_resolution().get_ms() / tau_m_readout_ ) ) + , is_recurrent_to_recurrent_conn_( false ) + , optimizer_( nullptr ) +{ +} + +template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >::~eprop_synapse_bsshslm_2020() +{ +} + +// This copy constructor is used to create instances from prototypes. +// Therefore, only parameter values are copied. +template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( const eprop_synapse_bsshslm_2020& es ) + : ConnectionBase( es ) + , weight_( es.weight_ ) + , t_previous_spike_( 0 ) + , t_previous_update_( 0 ) + , t_next_update_( kernel().simulation_manager.get_eprop_update_interval().get_steps() ) + , t_previous_trigger_spike_( 0 ) + , tau_m_readout_( es.tau_m_readout_ ) + , kappa_( std::exp( -Time::get_resolution().get_ms() / tau_m_readout_ ) ) + , is_recurrent_to_recurrent_conn_( es.is_recurrent_to_recurrent_conn_ ) + , optimizer_( es.optimizer_ ) +{ +} + +// This assignment operator is used to write a connection into the connection array. 
+template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >& +eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( const eprop_synapse_bsshslm_2020& es ) +{ + if ( this == &es ) + { + return *this; + } + + ConnectionBase::operator=( es ); + + weight_ = es.weight_; + t_previous_spike_ = es.t_previous_spike_; + t_previous_update_ = es.t_previous_update_; + t_next_update_ = es.t_next_update_; + t_previous_trigger_spike_ = es.t_previous_trigger_spike_; + tau_m_readout_ = es.tau_m_readout_; + kappa_ = es.kappa_; + is_recurrent_to_recurrent_conn_ = es.is_recurrent_to_recurrent_conn_; + optimizer_ = es.optimizer_; + + return *this; +} + +template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >::eprop_synapse_bsshslm_2020( eprop_synapse_bsshslm_2020&& es ) + : ConnectionBase( es ) + , weight_( es.weight_ ) + , t_previous_spike_( 0 ) + , t_previous_update_( 0 ) + , t_next_update_( es.t_next_update_ ) + , t_previous_trigger_spike_( 0 ) + , tau_m_readout_( es.tau_m_readout_ ) + , kappa_( es.kappa_ ) + , is_recurrent_to_recurrent_conn_( es.is_recurrent_to_recurrent_conn_ ) + , optimizer_( es.optimizer_ ) +{ + es.optimizer_ = nullptr; +} + +// This assignment operator is used to write a connection into the connection array. +template < typename targetidentifierT > +eprop_synapse_bsshslm_2020< targetidentifierT >& +eprop_synapse_bsshslm_2020< targetidentifierT >::operator=( eprop_synapse_bsshslm_2020&& es ) +{ + if ( this == &es ) + { + return *this; + } + + ConnectionBase::operator=( es ); + + weight_ = es.weight_; + t_previous_spike_ = es.t_previous_spike_; + t_previous_update_ = es.t_previous_update_; + t_next_update_ = es.t_next_update_; + t_previous_trigger_spike_ = es.t_previous_trigger_spike_; + tau_m_readout_ = es.tau_m_readout_; + kappa_ = es.kappa_; + is_recurrent_to_recurrent_conn_ = es.is_recurrent_to_recurrent_conn_; + + optimizer_ = es.optimizer_; + es.optimizer_ = nullptr; + + return *this; +} + +template < typename targetidentifierT > +inline void +eprop_synapse_bsshslm_2020< targetidentifierT >::check_connection( Node& s, + Node& t, + size_t receptor_type, + const CommonPropertiesType& cp ) +{ + // When we get here, delay has been set so we can check it. 
+ if ( get_delay_steps() != 1 ) + { + throw IllegalConnection( "eprop synapses currently require a delay of one simulation step" ); + } + + ConnTestDummyNode dummy_target; + ConnectionBase::check_connection_( dummy_target, s, t, receptor_type ); + + t.register_eprop_connection(); + + optimizer_ = cp.optimizer_cp_->get_optimizer(); +} + +template < typename targetidentifierT > +inline void +eprop_synapse_bsshslm_2020< targetidentifierT >::delete_optimizer() +{ + delete optimizer_; + // do not set to nullptr to allow detection of double deletion +} + +template < typename targetidentifierT > +bool +eprop_synapse_bsshslm_2020< targetidentifierT >::send( Event& e, + size_t thread, + const EpropSynapseBSSHSLM2020CommonProperties& cp ) +{ + Node* target = get_target( thread ); + assert( target ); + + const long t_spike = e.get_stamp().get_steps(); + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const long shift = target->get_shift(); + + const long interval_step = ( t_spike - shift ) % update_interval; + + if ( target->is_eprop_recurrent_node() and interval_step == 0 ) + { + return false; + } + + if ( t_previous_trigger_spike_ == 0 ) + { + t_previous_trigger_spike_ = t_spike; + } + + if ( t_previous_spike_ > 0 ) + { + const long t = t_spike >= t_next_update_ + shift ? t_next_update_ + shift : t_spike; + presyn_isis_.push_back( t - t_previous_spike_ ); + } + + if ( t_spike > t_next_update_ + shift ) + { + const long idx_current_update = ( t_spike - shift ) / update_interval; + const long t_current_update = idx_current_update * update_interval; + + target->write_update_to_history( t_previous_update_, t_current_update ); + + const double gradient = target->compute_gradient( + presyn_isis_, t_previous_update_, t_previous_trigger_spike_, kappa_, cp.average_gradient_ ); + + weight_ = optimizer_->optimized_weight( *cp.optimizer_cp_, idx_current_update, gradient, weight_ ); + + t_previous_update_ = t_current_update; + t_next_update_ = t_current_update + update_interval; + + t_previous_trigger_spike_ = t_spike; + } + + t_previous_spike_ = t_spike; + + e.set_receiver( *target ); + e.set_weight( weight_ ); + e.set_delay_steps( get_delay_steps() ); + e.set_rport( get_rport() ); + e(); + + return true; +} + +template < typename targetidentifierT > +void +eprop_synapse_bsshslm_2020< targetidentifierT >::get_status( DictionaryDatum& d ) const +{ + ConnectionBase::get_status( d ); + def< double >( d, names::weight, weight_ ); + def< double >( d, names::tau_m_readout, tau_m_readout_ ); + def< long >( d, names::size_of, sizeof( *this ) ); + + DictionaryDatum optimizer_dict = new Dictionary(); + + // The default_connection_ has no optimizer, therefore we need to protect it + if ( optimizer_ ) + { + optimizer_->get_status( optimizer_dict ); + ( *d )[ names::optimizer ] = optimizer_dict; + } +} + +template < typename targetidentifierT > +void +eprop_synapse_bsshslm_2020< targetidentifierT >::set_status( const DictionaryDatum& d, ConnectorModel& cm ) +{ + ConnectionBase::set_status( d, cm ); + if ( d->known( names::optimizer ) ) + { + // We must pass here if called by SetDefaults. In that case, the user will get and error + // message because the parameters for the synapse-specific optimizer have not been accessed. 
+ if ( optimizer_ ) + { + optimizer_->set_status( getValue< DictionaryDatum >( d->lookup( names::optimizer ) ) ); + } + } + + updateValue< double >( d, names::weight, weight_ ); + + if ( updateValue< double >( d, names::tau_m_readout, tau_m_readout_ ) ) + { + if ( tau_m_readout_ <= 0 ) + { + throw BadProperty( "Membrane time constant of readout neuron tau_m_readout > 0 required." ); + } + kappa_ = std::exp( -Time::get_resolution().get_ms() / tau_m_readout_ ); + } + + const auto& gcm = + dynamic_cast< const GenericConnectorModel< eprop_synapse_bsshslm_2020< targetidentifierT > >& >( cm ); + const CommonPropertiesType& epcp = gcm.get_common_properties(); + if ( weight_ < epcp.optimizer_cp_->get_Wmin() ) + { + throw BadProperty( "Minimal weight Wmin ≤ weight required." ); + } + + if ( weight_ > epcp.optimizer_cp_->get_Wmax() ) + { + throw BadProperty( "weight ≤ maximal weight Wmax required." ); + } +} + +} // namespace nest + +#endif // EPROP_SYNAPSE_BSSHSLM_2020_H diff --git a/models/weight_optimizer.cpp b/models/weight_optimizer.cpp new file mode 100644 index 0000000000..db0a07fedc --- /dev/null +++ b/models/weight_optimizer.cpp @@ -0,0 +1,257 @@ +/* + * weight_optimizer.cpp + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#include "weight_optimizer.h" + +// nestkernel +#include "exceptions.h" +#include "nest_names.h" + +// sli +#include "dictutils.h" + +namespace nest +{ +WeightOptimizerCommonProperties::WeightOptimizerCommonProperties() + : batch_size_( 1 ) + , eta_( 1e-4 ) + , Wmin_( -100.0 ) + , Wmax_( 100.0 ) +{ +} + +WeightOptimizerCommonProperties::WeightOptimizerCommonProperties( const WeightOptimizerCommonProperties& cp ) + : batch_size_( cp.batch_size_ ) + , eta_( cp.eta_ ) + , Wmin_( cp.Wmin_ ) + , Wmax_( cp.Wmax_ ) +{ +} + +void +WeightOptimizerCommonProperties::get_status( DictionaryDatum& d ) const +{ + def< std::string >( d, names::optimizer, get_name() ); + def< long >( d, names::batch_size, batch_size_ ); + def< double >( d, names::eta, eta_ ); + def< double >( d, names::Wmin, Wmin_ ); + def< double >( d, names::Wmax, Wmax_ ); +} + +void +WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d ) +{ + long new_batch_size = batch_size_; + updateValue< long >( d, names::batch_size, new_batch_size ); + if ( new_batch_size <= 0 ) + { + throw BadProperty( "Optimization batch_size > 0 required." ); + } + batch_size_ = new_batch_size; + + double new_eta = eta_; + updateValue< double >( d, names::eta, new_eta ); + if ( new_eta < 0 ) + { + throw BadProperty( "Learning rate eta ≥ 0 required." ); + } + eta_ = new_eta; + + double new_Wmin = Wmin_; + double new_Wmax = Wmax_; + updateValue< double >( d, names::Wmin, new_Wmin ); + updateValue< double >( d, names::Wmax, new_Wmax ); + if ( new_Wmin > new_Wmax ) + { + throw BadProperty( "Minimal weight Wmin ≤ maximal weight Wmax required." 
); + } + Wmin_ = new_Wmin; + Wmax_ = new_Wmax; +} + +WeightOptimizer::WeightOptimizer() + : sum_gradients_( 0.0 ) + , optimization_step_( 1 ) +{ +} + +void +WeightOptimizer::get_status( DictionaryDatum& d ) const +{ +} + +void +WeightOptimizer::set_status( const DictionaryDatum& d ) +{ +} + +double +WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp, + const size_t idx_current_update, + const double gradient, + double weight ) +{ + sum_gradients_ += gradient; + + const size_t current_optimization_step = 1 + idx_current_update / cp.batch_size_; + if ( optimization_step_ < current_optimization_step ) + { + sum_gradients_ /= cp.batch_size_; + weight = std::max( cp.Wmin_, std::min( optimize_( cp, weight, current_optimization_step ), cp.Wmax_ ) ); + optimization_step_ = current_optimization_step; + } + return weight; +} + +WeightOptimizerCommonProperties* +WeightOptimizerCommonPropertiesGradientDescent::clone() const +{ + return new WeightOptimizerCommonPropertiesGradientDescent( *this ); +} + +WeightOptimizer* +WeightOptimizerCommonPropertiesGradientDescent::get_optimizer() const +{ + return new WeightOptimizerGradientDescent(); +} + +WeightOptimizerGradientDescent::WeightOptimizerGradientDescent() + : WeightOptimizer() +{ +} + +double +WeightOptimizerGradientDescent::optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t ) +{ + weight -= cp.eta_ * sum_gradients_; + sum_gradients_ = 0; + return weight; +} + +WeightOptimizerCommonPropertiesAdam::WeightOptimizerCommonPropertiesAdam() + : WeightOptimizerCommonProperties() + , beta_1_( 0.9 ) + , beta_2_( 0.999 ) + , epsilon_( 1e-8 ) +{ +} + + +WeightOptimizerCommonProperties* +WeightOptimizerCommonPropertiesAdam::clone() const +{ + return new WeightOptimizerCommonPropertiesAdam( *this ); +} + +WeightOptimizer* +WeightOptimizerCommonPropertiesAdam::get_optimizer() const +{ + return new WeightOptimizerAdam(); +} + +void +WeightOptimizerCommonPropertiesAdam::get_status( DictionaryDatum& d ) const +{ + WeightOptimizerCommonProperties::get_status( d ); + + def< double >( d, names::beta_1, beta_1_ ); + def< double >( d, names::beta_2, beta_2_ ); + def< double >( d, names::epsilon, epsilon_ ); +} + +void +WeightOptimizerCommonPropertiesAdam::set_status( const DictionaryDatum& d ) +{ + WeightOptimizerCommonProperties::set_status( d ); + + updateValue< double >( d, names::beta_1, beta_1_ ); + updateValue< double >( d, names::beta_2, beta_2_ ); + updateValue< double >( d, names::epsilon, epsilon_ ); + + if ( beta_1_ < 0.0 or 1.0 <= beta_1_ ) + { + throw BadProperty( "For Adam optimizer, beta_1 from interval [0,1) required." ); + } + + if ( beta_2_ < 0.0 or 1.0 <= beta_2_ ) + { + throw BadProperty( "For Adam optimizer, beta_2 from interval [0,1) required." ); + } + + if ( epsilon_ < 0.0 ) + { + throw BadProperty( "For Adam optimizer, epsilon ≥ 0 required." 
); + } +} + +WeightOptimizerAdam::WeightOptimizerAdam() + : WeightOptimizer() + , m_( 0.0 ) + , v_( 0.0 ) +{ +} + +void +WeightOptimizerAdam::get_status( DictionaryDatum& d ) const +{ + WeightOptimizer::get_status( d ); + def< double >( d, names::m, m_ ); + def< double >( d, names::v, v_ ); +} + +void +WeightOptimizerAdam::set_status( const DictionaryDatum& d ) +{ + WeightOptimizer::set_status( d ); + updateValue< double >( d, names::m, m_ ); + updateValue< double >( d, names::v, v_ ); +} + + +double +WeightOptimizerAdam::optimize_( const WeightOptimizerCommonProperties& cp, + double weight, + size_t current_optimization_step ) +{ + const WeightOptimizerCommonPropertiesAdam& acp = dynamic_cast< const WeightOptimizerCommonPropertiesAdam& >( cp ); + + for ( ; optimization_step_ < current_optimization_step; ++optimization_step_ ) + { + const double beta_1_factor = 1.0 - std::pow( acp.beta_1_, optimization_step_ ); + const double beta_2_factor = 1.0 - std::pow( acp.beta_2_, optimization_step_ ); + + const double alpha = cp.eta_ * std::sqrt( beta_2_factor ) / beta_1_factor; + + m_ = acp.beta_1_ * m_ + ( 1.0 - acp.beta_1_ ) * sum_gradients_; + v_ = acp.beta_2_ * v_ + ( 1.0 - acp.beta_2_ ) * sum_gradients_ * sum_gradients_; + + weight -= alpha * m_ / ( std::sqrt( v_ ) + acp.epsilon_ ); + + // set gradients to zero for following iterations since more than + // one cycle indicates past learning periods with vanishing gradients + sum_gradients_ = 0.0; // reset for following iterations + } + + return weight; +} + +} // namespace nest diff --git a/models/weight_optimizer.h b/models/weight_optimizer.h new file mode 100644 index 0000000000..9cacba0745 --- /dev/null +++ b/models/weight_optimizer.h @@ -0,0 +1,359 @@ +/* + * weight_optimizer.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef WEIGHT_OPTIMIZER_H +#define WEIGHT_OPTIMIZER_H + +// Includes from sli +#include "dictdatum.h" + +namespace nest +{ + +/* BeginUserDocs: e-prop plasticity + +Short description ++++++++++++++++++ + +Selection of weight optimizers + +Description ++++++++++++ +A weight optimizer is an algorithm that adjusts the synaptic weights in a +network during training to minimize the loss function and thus improve the +network's performance on a given task. + +This method is an essential part of plasticity rules like e-prop plasticity. + +Currently two weight optimizers are implemented: gradient descent and the Adam optimizer. + +In gradient descent [1]_ the weights are optimized via: + +.. math:: + W_t = W_{t-1} - \eta \, g_t \,, + +whereby :math:`\eta` denotes the learning rate and :math:`g_t` the gradient of the current +time step :math:`t`. + +In the Adam scheme [2]_ the weights are optimized via: + +.. 
math:: + m_0 &= 0, v_0 = 0, t = 1 \,, \\ + m_t &= \beta_1 \, m_{t-1} + \left(1-\beta_1\right) \, g_t \,, \\ + v_t &= \beta_2 \, v_{t-1} + \left(1-\beta_2\right) \, g_t^2 \,, \\ + \hat{m}_t &= \frac{m_t}{1-\beta_1^t} \,, \\ + \hat{v}_t &= \frac{v_t}{1-\beta_2^t} \,, \\ + W_t &= W_{t-1} - \eta\frac{\hat{m_t}}{\sqrt{\hat{v}_t} + \epsilon} \,. + +Parameters +++++++++++ + +The following parameters can be set in the status dictionary. + +========== ==== ========================= ======= ================================= +**Common optimizer parameters** +----------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========== ==== ========================= ======= ================================= +batch_size 1 Size of batch +eta :math:`\eta` 1e-4 Learning rate +Wmax pA :math:`W_{ji}^\text{max}` 100.0 Maximal value for synaptic weight +Wmin pA :math:`W_{ji}^\text{min}` -100.0 Minimal value for synaptic weight +========== ==== ========================= ======= ================================= + +========= ==== =============== ================ ============== +**Gradient descent parameters (default optimizer)** +-------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========= ==== =============== ================ ============== +type gradient_descent Optimizer type +========= ==== =============== ================ ============== + +========= ==== ================ ======= ================================================= +**Adam optimizer parameters** +----------------------------------------------------------------------------------------- +Parameter Unit Math equivalent Default Description +========= ==== ================ ======= ================================================= +type adam Optimizer type +beta_1 :math:`\beta_1` 0.9 Exponential decay rate for first moment estimate +beta_2 :math:`\beta_2` 0.999 Exponential decay rate for second moment estimate +epsilon :math:`\epsilon` 1e-8 Small constant for numerical stability +========= ==== ================ ======= ================================================= + +The following state variables evolve during simulation. + +============== ==== =============== ============= ========================== +**Adam optimizer state variables for individual synapses** +---------------------------------------------------------------------------- +State variable Unit Math equivalent Initial value Description +============== ==== =============== ============= ========================== +m :math:`m` 0.0 First moment estimate +v :math:`v` 0.0 Second moment raw estimate +============== ==== =============== ============= ========================== + + +References +++++++++++ +.. [1] Huh, D. & Sejnowski, T. J. Gradient descent for spiking neural networks. 32nd + Conference on Neural Information Processing Systems (2018). +.. [2] Kingma DP, Ba JL (2015). Adam: A method for stochastic optimization. + Proceedings of International Conference on Learning Representations (ICLR). + https://doi.org/10.48550/arXiv.1412.6980 + +See also +++++++++ + +Examples using this model +++++++++++++++++++++++++++ + +.. listexamples:: eprop_synapse_bsshslm_2020 + +EndUserDocs */ + +class WeightOptimizer; + +/** + * Base class implementing common properties of a weight optimizer model. + * + * The CommonProperties of synapse models supporting weight optimization own an object of this class hierarchy. 
+ * The values in these objects are used by the synapse-specific optimizer object. + * Change of the optimizer type is only possible before synapses of the model have been created. + */ +class WeightOptimizerCommonProperties +{ +public: + //! Default constructor. + WeightOptimizerCommonProperties(); + + //! Destructor. + virtual ~WeightOptimizerCommonProperties() + { + } + + //! Copy constructor. + WeightOptimizerCommonProperties( const WeightOptimizerCommonProperties& ); + + //! Assignment operator. + WeightOptimizer& operator=( const WeightOptimizer& ) = delete; + + //! Get parameter dictionary. + virtual void get_status( DictionaryDatum& d ) const; + + //! Update parameters in parameter dictionary. + virtual void set_status( const DictionaryDatum& d ); + + //! Clone constructor. + virtual WeightOptimizerCommonProperties* clone() const = 0; + + //! Get optimizer. + virtual WeightOptimizer* get_optimizer() const = 0; + + //! Get minimal value for synaptic weight. + double + get_Wmin() const + { + return Wmin_; + } + + //! Get maximal value for synaptic weight. + double + get_Wmax() const + { + return Wmax_; + } + + //! Get optimizer name. + virtual std::string get_name() const = 0; + +public: + //! Size of an optimization batch. + size_t batch_size_; + + //! Learning rate. + double eta_; + + //! Minimal value for synaptic weight. + double Wmin_; + + //! Maximal value for synaptic weight. + double Wmax_; +}; + +/** + * Base class implementing a weight optimizer model. + * + * An optimizer is used by a synapse that supports this mechanism to optimize the weight. + * + * An optimizer may have an internal state which is maintained from call to call of the `optimized_weight()` method. + * Each optimized object belongs to exactly one synapse. + */ +class WeightOptimizer +{ +public: + //! Default constructor. + WeightOptimizer(); + + //! Destructor. + virtual ~WeightOptimizer() + { + } + + //! Copy constructor. + WeightOptimizer( const WeightOptimizer& ) = default; + + //! Assignment operator. + WeightOptimizer& operator=( const WeightOptimizer& ) = delete; + + //! Get parameter dictionary. + virtual void get_status( DictionaryDatum& d ) const; + + //! Update values in parameter dictionary. + virtual void set_status( const DictionaryDatum& d ); + + //! Return optimized weight based on current weight. + double optimized_weight( const WeightOptimizerCommonProperties& cp, + const size_t idx_current_update, + const double gradient, + double weight ); + +protected: + //! Perform specific optimization. + virtual double optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t current_opt_step ) = 0; + + //! Sum of gradients accumulated in current batch. + double sum_gradients_; + + //! Current optimization step, whereby optimization happens every batch_size_ steps. + size_t optimization_step_; +}; + +/** + * Base class implementing a gradient descent weight optimizer model. + */ +class WeightOptimizerGradientDescent : public WeightOptimizer +{ +public: + //! Default constructor. + WeightOptimizerGradientDescent(); + + //! Copy constructor. + WeightOptimizerGradientDescent( const WeightOptimizerGradientDescent& ) = default; + + //! Assignment operator. + WeightOptimizerGradientDescent& operator=( const WeightOptimizerGradientDescent& ) = delete; + +private: + double optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t current_opt_step ) override; +}; + +/** + * Class implementing common properties of a gradient descent weight optimizer model. 
+ */ +class WeightOptimizerCommonPropertiesGradientDescent : public WeightOptimizerCommonProperties +{ + //! Friend class for gradient descent weight optimizer model. + friend class WeightOptimizerGradientDescent; + +public: + //! Assignment operator. + WeightOptimizerCommonPropertiesGradientDescent& operator=( + const WeightOptimizerCommonPropertiesGradientDescent& ) = delete; + + WeightOptimizerCommonProperties* clone() const override; + WeightOptimizer* get_optimizer() const override; + + std::string + get_name() const override + { + return "gradient_descent"; + } +}; + +/** + * Base class implementing an Adam weight optimizer model. + */ +class WeightOptimizerAdam : public WeightOptimizer +{ +public: + //! Default constructor. + WeightOptimizerAdam(); + + //! Copy constructor. + WeightOptimizerAdam( const WeightOptimizerAdam& ) = default; + + //! Assignment operator. + WeightOptimizerAdam& operator=( const WeightOptimizerAdam& ) = delete; + + void get_status( DictionaryDatum& d ) const override; + void set_status( const DictionaryDatum& d ) override; + +private: + double optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t current_opt_step ) override; + + //! First moment estimate variable. + double m_; + + //! Second moment estimate variable. + double v_; +}; + +/** + * Class implementing common properties of an Adam weight optimizer model. + */ +class WeightOptimizerCommonPropertiesAdam : public WeightOptimizerCommonProperties +{ + //! Friend class for Adam weight optimizer model. + friend class WeightOptimizerAdam; + +public: + //! Default constructor. + WeightOptimizerCommonPropertiesAdam(); + + //! Assignment operator. + WeightOptimizerCommonPropertiesAdam& operator=( const WeightOptimizerCommonPropertiesAdam& ) = delete; + + WeightOptimizerCommonProperties* clone() const override; + WeightOptimizer* get_optimizer() const override; + + void get_status( DictionaryDatum& d ) const override; + void set_status( const DictionaryDatum& d ) override; + + std::string + get_name() const override + { + return "adam"; + } + +private: + //! Exponential decay rate for first moment estimate. + double beta_1_; + + //! Exponential decay rate for second moment estimate. + double beta_2_; + + //! Small constant for numerical stability. 
+ double epsilon_; +}; + +} // namespace nest + +#endif // WEIGHT_OPTIMIZER_H diff --git a/modelsets/eprop b/modelsets/eprop new file mode 100644 index 0000000000..c373835f6b --- /dev/null +++ b/modelsets/eprop @@ -0,0 +1,19 @@ +# Minimal modelset for spiking neuron simulations with e-prop plasticity + +multimeter +spike_recorder +weight_recorder + +spike_generator +step_rate_generator +poisson_generator + +eprop_iaf_bsshslm_2020 +eprop_iaf_adapt_bsshslm_2020 +eprop_readout_bsshslm_2020 +parrot_neuron + +eprop_learning_signal_connection_bsshslm_2020 +eprop_synapse_bsshslm_2020 +rate_connection_delayed +static_synapse diff --git a/modelsets/full b/modelsets/full index be90dc17bf..2fb7bd3c3e 100644 --- a/modelsets/full +++ b/modelsets/full @@ -1,6 +1,5 @@ # Full modelset for NEST including spiking, binary and rate models - ac_generator aeif_cond_alpha aeif_cond_alpha_astro @@ -22,6 +21,11 @@ correlomatrix_detector correlospinmatrix_detector dc_generator diffusion_connection +eprop_iaf_bsshslm_2020 +eprop_iaf_adapt_bsshslm_2020 +eprop_readout_bsshslm_2020 +eprop_synapse_bsshslm_2020 +eprop_learning_signal_connection_bsshslm_2020 erfc_neuron gamma_sup_generator gap_junction diff --git a/modelsets/iaf_minimal b/modelsets/iaf_minimal index 1d8387da7e..660657cdc3 100644 --- a/modelsets/iaf_minimal +++ b/modelsets/iaf_minimal @@ -1,6 +1,5 @@ # Minimal modelset for spiking neuron simulations with NEST - dc_generator iaf_psc_alpha multimeter diff --git a/nestkernel/CMakeLists.txt b/nestkernel/CMakeLists.txt index 693e92f111..b354fb0791 100644 --- a/nestkernel/CMakeLists.txt +++ b/nestkernel/CMakeLists.txt @@ -23,6 +23,7 @@ set ( nestkernel_sources archiving_node.h archiving_node.cpp clopath_archiving_node.h clopath_archiving_node.cpp urbanczik_archiving_node.h urbanczik_archiving_node_impl.h + eprop_archiving_node.h eprop_archiving_node_impl.h eprop_archiving_node.cpp common_synapse_properties.h common_synapse_properties.cpp connection.h connection_label.h diff --git a/nestkernel/common_synapse_properties.h b/nestkernel/common_synapse_properties.h index 4923d83e32..fab02945c6 100644 --- a/nestkernel/common_synapse_properties.h +++ b/nestkernel/common_synapse_properties.h @@ -45,7 +45,7 @@ class TimeConverter; * Class containing the common properties for all connections of a certain type. * * Everything that needs to be stored commonly for all synapses goes into a - * CommonProperty class derived by this base class. + * CommonProperty class derived from this base class. * Base class for all CommonProperty classes. * If the synapse type does not have any common properties, this class may be * used as a placeholder. diff --git a/nestkernel/connection_manager.cpp b/nestkernel/connection_manager.cpp index 146f6f2081..6dd7c012fb 100644 --- a/nestkernel/connection_manager.cpp +++ b/nestkernel/connection_manager.cpp @@ -46,6 +46,7 @@ #include "connector_base.h" #include "connector_model.h" #include "delay_checker.h" +#include "eprop_archiving_node.h" #include "exceptions.h" #include "kernel_manager.h" #include "mpi_manager_impl.h" @@ -855,6 +856,14 @@ nest::ConnectionManager::connect_( Node& source, throw NotImplemented( "This synapse model is not supported by the neuron model of at least one connection." 
); } + const bool eprop_archiving = conn_model.has_property( ConnectionModelProperties::REQUIRES_EPROP_ARCHIVING ); + if ( eprop_archiving + and not( dynamic_cast< EpropArchivingNodeRecurrent* >( &target ) + or dynamic_cast< EpropArchivingNodeReadout* >( &target ) ) ) + { + throw NotImplemented( "This synapse model is not supported by the neuron model of at least one connection." ); + } + const bool is_primary = conn_model.has_property( ConnectionModelProperties::IS_PRIMARY ); conn_model.add_connection( source, target, connections_[ tid ], syn_id, params, delay, weight ); source_table_.add_source( tid, syn_id, s_node_id, is_primary ); diff --git a/nestkernel/connector_base.h b/nestkernel/connector_base.h index a7f5cf062f..7cdd91b1e8 100644 --- a/nestkernel/connector_base.h +++ b/nestkernel/connector_base.h @@ -56,6 +56,10 @@ namespace nest /** * Base class to allow storing Connectors for different synapse types * in vectors. We define the interface here to avoid casting. + * + * @note If any member functions need to do something special for a given connection type, + * declare specializations in the corresponding header file and define them in the corresponding + * source file. For an example, see `eprop_synapse_bsshslm_2020`. */ class ConnectorBase { @@ -268,6 +272,12 @@ class Connector : public ConnectorBase C_.push_back( c ); } + void + push_back( ConnectionT&& c ) + { + C_.push_back( std::move( c ) ); + } + void get_connection( const size_t source_node_id, const size_t target_node_id, diff --git a/nestkernel/connector_model.h b/nestkernel/connector_model.h index 504f8a8e80..0a7f83ce8e 100644 --- a/nestkernel/connector_model.h +++ b/nestkernel/connector_model.h @@ -57,7 +57,8 @@ enum class ConnectionModelProperties : unsigned SUPPORTS_WFR = 1 << 4, REQUIRES_SYMMETRIC = 1 << 5, REQUIRES_CLOPATH_ARCHIVING = 1 << 6, - REQUIRES_URBANCZIK_ARCHIVING = 1 << 7 + REQUIRES_URBANCZIK_ARCHIVING = 1 << 7, + REQUIRES_EPROP_ARCHIVING = 1 << 8 }; template <> @@ -120,6 +121,7 @@ class ConnectorModel virtual SecondaryEvent* get_secondary_event() = 0; + virtual size_t get_syn_id() const = 0; virtual void set_syn_id( synindex syn_id ) = 0; std::string @@ -192,6 +194,7 @@ class GenericConnectorModel : public ConnectorModel return cp_; } + size_t get_syn_id() const override; void set_syn_id( synindex syn_id ) override; void check_synapse_params( const DictionaryDatum& syn_spec ) const override; diff --git a/nestkernel/connector_model_impl.h b/nestkernel/connector_model_impl.h index 73a3969508..fc831cb916 100644 --- a/nestkernel/connector_model_impl.h +++ b/nestkernel/connector_model_impl.h @@ -203,6 +203,13 @@ GenericConnectorModel< ConnectionT >::used_default_delay() } } +template < typename ConnectionT > +size_t +GenericConnectorModel< ConnectionT >::get_syn_id() const +{ + return default_connection_.get_syn_id(); +} + template < typename ConnectionT > void GenericConnectorModel< ConnectionT >::set_syn_id( synindex syn_id ) @@ -311,7 +318,7 @@ GenericConnectorModel< ConnectionT >::add_connection_( Node& src, assert( connector ); Connector< ConnectionT >* vc = static_cast< Connector< ConnectionT >* >( connector ); - vc->push_back( connection ); + vc->push_back( std::move( connection ) ); } } // namespace nest diff --git a/nestkernel/eprop_archiving_node.cpp b/nestkernel/eprop_archiving_node.cpp new file mode 100644 index 0000000000..f607ca34ec --- /dev/null +++ b/nestkernel/eprop_archiving_node.cpp @@ -0,0 +1,165 @@ +/* + * eprop_archiving_node.cpp + * + * This file is part of NEST. 
+ * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +// nestkernel +#include "eprop_archiving_node.h" +#include "eprop_archiving_node_impl.h" +#include "kernel_manager.h" + +// sli +#include "dictutils.h" + +namespace nest +{ + +EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent() + : EpropArchivingNode() + , n_spikes_( 0 ) +{ +} + +EpropArchivingNodeRecurrent::EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& n ) + : EpropArchivingNode( n ) + , n_spikes_( n.n_spikes_ ) +{ +} + +void +EpropArchivingNodeRecurrent::write_surrogate_gradient_to_history( const long time_step, + const double surrogate_gradient ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + eprop_history_.emplace_back( time_step, surrogate_gradient, 0.0 ); +} + +void +EpropArchivingNodeRecurrent::write_learning_signal_to_history( const long time_step, const double learning_signal ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = delay_rec_out_ + delay_out_norm_ + delay_out_rec_; + + auto it_hist = get_eprop_history( time_step - shift ); + const auto it_hist_end = get_eprop_history( time_step - shift + delay_out_rec_ ); + + for ( ; it_hist != it_hist_end; ++it_hist ) + { + it_hist->learning_signal_ += learning_signal; + } +} + +void +EpropArchivingNodeRecurrent::write_firing_rate_reg_to_history( const long t_current_update, + const double f_target, + const double c_reg ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const double update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + const double dt = Time::get_resolution().get_ms(); + const long shift = Time::get_resolution().get_steps(); + + const double f_av = n_spikes_ / update_interval; + const double f_target_ = f_target * dt; // convert from spikes/ms to spikes/step + const double firing_rate_reg = c_reg * ( f_av - f_target_ ) / update_interval; + + firing_rate_reg_history_.emplace_back( t_current_update + shift, firing_rate_reg ); +} + +std::vector< HistEntryEpropFiringRateReg >::iterator +EpropArchivingNodeRecurrent::get_firing_rate_reg_history( const long time_step ) +{ + const auto it_hist = std::lower_bound( firing_rate_reg_history_.begin(), firing_rate_reg_history_.end(), time_step ); + assert( it_hist != firing_rate_reg_history_.end() ); + + return it_hist; +} + +double +EpropArchivingNodeRecurrent::get_learning_signal_from_history( const long time_step ) +{ + const long shift = delay_rec_out_ + delay_out_norm_ + delay_out_rec_; + + const auto it = get_eprop_history( time_step - shift ); + if ( it == eprop_history_.end() ) + { + return 0; + } + + return it->learning_signal_; +} + +void +EpropArchivingNodeRecurrent::erase_used_firing_rate_reg_history() +{ + auto it_update_hist = update_history_.begin(); + auto it_reg_hist = firing_rate_reg_history_.begin(); + + while ( it_update_hist != update_history_.end() and it_reg_hist != firing_rate_reg_history_.end() ) + { + if ( 
it_update_hist->access_counter_ == 0 ) + { + it_reg_hist = firing_rate_reg_history_.erase( it_reg_hist ); + } + else + { + ++it_reg_hist; + } + ++it_update_hist; + } +} + +EpropArchivingNodeReadout::EpropArchivingNodeReadout() + : EpropArchivingNode() +{ +} + +EpropArchivingNodeReadout::EpropArchivingNodeReadout( const EpropArchivingNodeReadout& n ) + : EpropArchivingNode( n ) +{ +} + +void +EpropArchivingNodeReadout::write_error_signal_to_history( const long time_step, const double error_signal ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = delay_out_norm_; + + eprop_history_.emplace_back( time_step - shift, error_signal ); +} + + +} // namespace nest diff --git a/nestkernel/eprop_archiving_node.h b/nestkernel/eprop_archiving_node.h new file mode 100644 index 0000000000..187b22341b --- /dev/null +++ b/nestkernel/eprop_archiving_node.h @@ -0,0 +1,184 @@ +/* + * eprop_archiving_node.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_ARCHIVING_NODE_H +#define EPROP_ARCHIVING_NODE_H + +// nestkernel +#include "histentry.h" +#include "nest_time.h" +#include "nest_types.h" +#include "node.h" + +// sli +#include "dictdatum.h" + +namespace nest +{ + +/** + * Base class implementing an intermediate archiving node model for node models supporting e-prop plasticity. + * + * A node which archives the history of dynamic variables, the firing rate + * regularization, and update times needed to calculate the weight updates for + * e-prop plasticity. It further provides a set of get, write, and set functions + * for these histories and the hardcoded shifts to synchronize the factors of + * the plasticity rule. + */ +template < typename HistEntryT > +class EpropArchivingNode : public Node +{ + +public: + //! Default constructor. + EpropArchivingNode(); + + //! Copy constructor. + EpropArchivingNode( const EpropArchivingNode& ); + + //! Initialize the update history and register the eprop synapse. + void register_eprop_connection() override; + + //! Register current update in the update history and deregister previous update. + void write_update_to_history( const long t_previous_update, const long t_current_update ) override; + + //! Get an iterator pointing to the update history entry of the given time step. + std::vector< HistEntryEpropUpdate >::iterator get_update_history( const long time_step ); + + //! Get an iterator pointing to the eprop history entry of the given time step. + typename std::vector< HistEntryT >::iterator get_eprop_history( const long time_step ); + + //! Erase update history parts for which the access counter has decreased to zero since no synapse needs them + //! any longer. + void erase_used_update_history(); + + //! Erase update intervals from the e-prop history in which each synapse has either not transmitted a spike or has + //! transmitted a spike in a more recent update interval. 
+ void erase_used_eprop_history(); + +protected: + //!< Number of incoming eprop synapses + size_t eprop_indegree_; + + //! History of updates still needed by at least one synapse. + std::vector< HistEntryEpropUpdate > update_history_; + + //! History of dynamic variables needed for e-prop plasticity. + std::vector< HistEntryT > eprop_history_; + + // The following shifts are, for now, hardcoded to 1 time step since the current + // implementation only works if all the delays are equal to the simulation resolution. + + //! Offset since generator signals start from time step 1. + const long offset_gen_ = 1; + + //! Connection delay from input to recurrent neurons. + const long delay_in_rec_ = 1; + + //! Connection delay from recurrent to output neurons. + const long delay_rec_out_ = 1; + + //! Connection delay between output neurons for normalization. + const long delay_out_norm_ = 1; + + //! Connection delay from output neurons to recurrent neurons. + const long delay_out_rec_ = 1; +}; + +/** + * Class implementing an intermediate archiving node model for recurrent node models supporting e-prop plasticity. + */ +class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRecurrent > +{ + +public: + //! Default constructor. + EpropArchivingNodeRecurrent(); + + //! Copy constructor. + EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& ); + + //! Create an entry in the eprop history for the given time step and surrogate gradient. + void write_surrogate_gradient_to_history( const long time_step, const double surrogate_gradient ); + + //! Update the learning signal in the eprop history entry of the given time step by writing the value of the incoming + //! learning signal to the history or adding it to the existing value in case of multiple readout neurons. + void write_learning_signal_to_history( const long time_step, const double learning_signal ); + + //! Create an entry in the firing rate regularization history for the current update. + void write_firing_rate_reg_to_history( const long t_current_update, const double f_target, const double c_reg ); + + //! Get an iterator pointing to the firing rate regularization history of the given time step. + std::vector< HistEntryEpropFiringRateReg >::iterator get_firing_rate_reg_history( const long time_step ); + + //! Return learning signal from history for given time step or zero if time step not in history + double get_learning_signal_from_history( const long time_step ); + + //! Erase parts of the firing rate regularization history for which the access counter in the update history has + //! decreased to zero since no synapse needs them any longer. + void erase_used_firing_rate_reg_history(); + + //! Count emitted spike for the firing rate regularization. + void count_spike(); + + //! Reset spike count for the firing rate regularization. + void reset_spike_count(); + +private: + //! Count of the emitted spikes for the firing rate regularization. + size_t n_spikes_; + + //! History of the firing rate regularization. + std::vector< HistEntryEpropFiringRateReg > firing_rate_reg_history_; +}; + +inline void +EpropArchivingNodeRecurrent::count_spike() +{ + ++n_spikes_; +} + +inline void +EpropArchivingNodeRecurrent::reset_spike_count() +{ + n_spikes_ = 0; +} + +/** + * Class implementing an intermediate archiving node model for readout node models supporting e-prop plasticity. + */ +class EpropArchivingNodeReadout : public EpropArchivingNode< HistEntryEpropReadout > +{ +public: + //! Default constructor. 
+ EpropArchivingNodeReadout(); + + //! Copy constructor. + EpropArchivingNodeReadout( const EpropArchivingNodeReadout& ); + + //! Create an entry in the eprop history for the given time step and error signal. + void write_error_signal_to_history( const long time_step, const double error_signal ); +}; + +} // namespace nest + +#endif // EPROP_ARCHIVING_NODE_H diff --git a/nestkernel/eprop_archiving_node_impl.h b/nestkernel/eprop_archiving_node_impl.h new file mode 100644 index 0000000000..e2798337a5 --- /dev/null +++ b/nestkernel/eprop_archiving_node_impl.h @@ -0,0 +1,172 @@ +/* + * eprop_archiving_node_impl.h + * + * This file is part of NEST. + * + * Copyright (C) 2004 The NEST Initiative + * + * NEST is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * NEST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with NEST. If not, see . + * + */ + +#ifndef EPROP_ARCHIVING_NODE_IMPL_H +#define EPROP_ARCHIVING_NODE_IMPL_H + +#include "eprop_archiving_node.h" + +// Includes from nestkernel: +#include "kernel_manager.h" + +// Includes from sli: +#include "dictutils.h" + +namespace nest +{ + +template < typename HistEntryT > +EpropArchivingNode< HistEntryT >::EpropArchivingNode() + : Node() + , eprop_indegree_( 0 ) +{ +} + +template < typename HistEntryT > +EpropArchivingNode< HistEntryT >::EpropArchivingNode( const EpropArchivingNode& n ) + : Node( n ) + , eprop_indegree_( n.eprop_indegree_ ) +{ +} + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::register_eprop_connection() +{ + ++eprop_indegree_; + + const long shift = get_shift(); + + const auto it_hist = get_update_history( shift ); + + if ( it_hist == update_history_.end() or it_hist->t_ != shift ) + { + update_history_.insert( it_hist, HistEntryEpropUpdate( shift, 1 ) ); + } + else + { + ++it_hist->access_counter_; + } +} + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::write_update_to_history( const long t_previous_update, const long t_current_update ) +{ + if ( eprop_indegree_ == 0 ) + { + return; + } + + const long shift = get_shift(); + + const auto it_hist_curr = get_update_history( t_current_update + shift ); + + if ( it_hist_curr != update_history_.end() and it_hist_curr->t_ == t_current_update + shift ) + { + ++it_hist_curr->access_counter_; + } + else + { + update_history_.insert( it_hist_curr, HistEntryEpropUpdate( t_current_update + shift, 1 ) ); + } + + const auto it_hist_prev = get_update_history( t_previous_update + shift ); + + if ( it_hist_prev != update_history_.end() and it_hist_prev->t_ == t_previous_update + shift ) + { + // If an entry exists for the previous update time, decrement its access counter + --it_hist_prev->access_counter_; + } +} + +template < typename HistEntryT > +std::vector< HistEntryEpropUpdate >::iterator +EpropArchivingNode< HistEntryT >::get_update_history( const long time_step ) +{ + return std::lower_bound( update_history_.begin(), update_history_.end(), time_step ); +} + +template < typename HistEntryT > +typename std::vector< HistEntryT >::iterator +EpropArchivingNode< HistEntryT 
>::get_eprop_history( const long time_step ) +{ + return std::lower_bound( eprop_history_.begin(), eprop_history_.end(), time_step ); +} + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::erase_used_eprop_history() +{ + if ( eprop_history_.empty() // nothing to remove + or update_history_.empty() // no time markers to check + ) + { + return; + } + + const long update_interval = kernel().simulation_manager.get_eprop_update_interval().get_steps(); + + auto it_update_hist = update_history_.begin(); + + for ( long t = update_history_.begin()->t_; + t <= ( update_history_.end() - 1 )->t_ and it_update_hist != update_history_.end(); + t += update_interval ) + { + if ( it_update_hist->t_ == t ) + { + ++it_update_hist; + } + else + { + const auto it_eprop_hist_from = get_eprop_history( t ); + const auto it_eprop_hist_to = get_eprop_history( t + update_interval ); + eprop_history_.erase( it_eprop_hist_from, it_eprop_hist_to ); // erase found entries since no longer used + } + } + const auto it_eprop_hist_from = get_eprop_history( 0 ); + const auto it_eprop_hist_to = get_eprop_history( update_history_.begin()->t_ ); + eprop_history_.erase( it_eprop_hist_from, it_eprop_hist_to ); // erase found entries since no longer used +} + +template < typename HistEntryT > +void +EpropArchivingNode< HistEntryT >::erase_used_update_history() +{ + auto it_hist = update_history_.begin(); + while ( it_hist != update_history_.end() ) + { + if ( it_hist->access_counter_ == 0 ) + { + // erase() invalidates the iterator, but returns a new, valid iterator + it_hist = update_history_.erase( it_hist ); + } + else + { + ++it_hist; + } + } +} + +} // namespace nest + +#endif // EPROP_ARCHIVING_NODE_IMPL_H diff --git a/nestkernel/event.cpp b/nestkernel/event.cpp index 587153505b..6d33a5f615 100644 --- a/nestkernel/event.cpp +++ b/nestkernel/event.cpp @@ -151,6 +151,12 @@ DiffusionConnectionEvent::operator()() receiver_->handle( *this ); } +void +LearningSignalConnectionEvent::operator()() +{ + receiver_->handle( *this ); +} + void SICEvent::operator()() { diff --git a/nestkernel/event.h b/nestkernel/event.h index da36b63edc..13860bd88a 100644 --- a/nestkernel/event.h +++ b/nestkernel/event.h @@ -97,6 +97,7 @@ class Node; * @see DiffusionConnectionEvent * @see GapJunctionEvent * @see InstantaneousRateConnectionEvent + * @see LearningSignalConnectionEvent * @ingroup event_interface */ diff --git a/nestkernel/genericmodel.h b/nestkernel/genericmodel.h index 91509e4d0a..a6f5ddc9cb 100644 --- a/nestkernel/genericmodel.h +++ b/nestkernel/genericmodel.h @@ -81,6 +81,8 @@ class GenericModel : public Model void sends_secondary_event( DelayedRateConnectionEvent& re ) override; + void sends_secondary_event( LearningSignalConnectionEvent& re ) override; + void sends_secondary_event( SICEvent& sic ) override; Node const& get_prototype() const override; @@ -215,6 +217,13 @@ GenericModel< ElementT >::sends_secondary_event( DelayedRateConnectionEvent& re return proto_.sends_secondary_event( re ); } +template < typename ElementT > +inline void +GenericModel< ElementT >::sends_secondary_event( LearningSignalConnectionEvent& re ) +{ + return proto_.sends_secondary_event( re ); +} + template < typename ElementT > inline void GenericModel< ElementT >::sends_secondary_event( SICEvent& sic ) diff --git a/nestkernel/histentry.cpp b/nestkernel/histentry.cpp index 03fb02133f..b56bb1ce87 100644 --- a/nestkernel/histentry.cpp +++ b/nestkernel/histentry.cpp @@ -36,3 +36,33 @@ nest::histentry_extended::histentry_extended( double 
t, double dw, size_t access , access_counter_( access_counter ) { } + +nest::HistEntryEprop::HistEntryEprop( long t ) + : t_( t ) +{ +} + +nest::HistEntryEpropRecurrent::HistEntryEpropRecurrent( long t, double surrogate_gradient, double learning_signal ) + : HistEntryEprop( t ) + , surrogate_gradient_( surrogate_gradient ) + , learning_signal_( learning_signal ) +{ +} + +nest::HistEntryEpropReadout::HistEntryEpropReadout( long t, double error_signal ) + : HistEntryEprop( t ) + , error_signal_( error_signal ) +{ +} + +nest::HistEntryEpropUpdate::HistEntryEpropUpdate( long t, size_t access_counter ) + : HistEntryEprop( t ) + , access_counter_( access_counter ) +{ +} + +nest::HistEntryEpropFiringRateReg::HistEntryEpropFiringRateReg( long t, double firing_rate_reg ) + : HistEntryEprop( t ) + , firing_rate_reg_( firing_rate_reg ) +{ +} diff --git a/nestkernel/histentry.h b/nestkernel/histentry.h index 8e148df6fd..0d5b1392bf 100644 --- a/nestkernel/histentry.h +++ b/nestkernel/histentry.h @@ -49,9 +49,86 @@ class histentry_extended public: histentry_extended( double t, double dw, size_t access_counter ); - double t_; //!< point in time when spike occurred (in ms) - double dw_; //!< value dependend on the additional factor - size_t access_counter_; //!< access counter to enable removal of the entry, once all neurons read it + double t_; //!< point in time when spike occurred (in ms) + double dw_; + //! how often this entry was accessed (to enable removal, once read by all + //! neurons which need it) + size_t access_counter_; + + friend bool operator<( const histentry_extended he, double t ); +}; + +inline bool +operator<( const histentry_extended he, double t ) +{ + return he.t_ < t; +} + +/** + * Base class implementing history entries for e-prop plasticity. + */ +class HistEntryEprop +{ +public: + HistEntryEprop( long t ); + + long t_; + virtual ~HistEntryEprop() + { + } + + friend bool operator<( const HistEntryEprop& he, long t ); +}; + +inline bool +operator<( const HistEntryEprop& he, long t ) +{ + return he.t_ < t; +} + +/** + * Class implementing entries of the recurrent node model's history of e-prop dynamic variables. + */ +class HistEntryEpropRecurrent : public HistEntryEprop +{ +public: + HistEntryEpropRecurrent( long t, double surrogate_gradient, double learning_signal ); + + double surrogate_gradient_; + double learning_signal_; +}; + +/** + * Class implementing entries of the readout node model's history of e-prop dynamic variables. + */ +class HistEntryEpropReadout : public HistEntryEprop +{ +public: + HistEntryEpropReadout( long t, double error_signal ); + + double error_signal_; +}; + +/** + * Class implementing entries of the update history for e-prop plasticity. + */ +class HistEntryEpropUpdate : public HistEntryEprop +{ +public: + HistEntryEpropUpdate( long t, size_t access_counter ); + + size_t access_counter_; +}; + +/** + * Class implementing entries of the firing rate regularization history for e-prop plasticity. 
+ */ +class HistEntryEpropFiringRateReg : public HistEntryEprop +{ +public: + HistEntryEpropFiringRateReg( long t, double firing_rate_reg ); + + double firing_rate_reg_; }; } diff --git a/nestkernel/model.h b/nestkernel/model.h index 7bd8942a62..19372e9a6a 100644 --- a/nestkernel/model.h +++ b/nestkernel/model.h @@ -159,6 +159,7 @@ class Model virtual void sends_secondary_event( InstantaneousRateConnectionEvent& re ) = 0; virtual void sends_secondary_event( DiffusionConnectionEvent& de ) = 0; virtual void sends_secondary_event( DelayedRateConnectionEvent& re ) = 0; + virtual void sends_secondary_event( LearningSignalConnectionEvent& re ) = 0; virtual void sends_secondary_event( SICEvent& sic ) = 0; /** diff --git a/nestkernel/nest_names.cpp b/nestkernel/nest_names.cpp index b0364f2584..2d6b867aa7 100644 --- a/nestkernel/nest_names.cpp +++ b/nestkernel/nest_names.cpp @@ -47,6 +47,9 @@ const Name a_thresh_th( "a_thresh_th" ); const Name a_thresh_tl( "a_thresh_tl" ); const Name acceptable_latency( "acceptable_latency" ); const Name activity( "activity" ); +const Name adapt_beta( "adapt_beta" ); +const Name adapt_tau( "adapt_tau" ); +const Name adaptation( "adaptation" ); const Name adapting_threshold( "adapting_threshold" ); const Name adaptive_target_buffers( "adaptive_target_buffers" ); const Name add_compartments( "add_compartments" ); @@ -71,10 +74,14 @@ const Name asc_decay( "asc_decay" ); const Name asc_init( "asc_init" ); const Name asc_r( "asc_r" ); const Name available( "available" ); +const Name average_gradient( "average_gradient" ); const Name azimuth_angle( "azimuth_angle" ); const Name b( "b" ); +const Name batch_size( "batch_size" ); const Name beta( "beta" ); +const Name beta_1( "beta_1" ); +const Name beta_2( "beta_2" ); const Name beta_Ca( "beta_Ca" ); const Name biological_time( "biological_time" ); const Name box( "box" ); @@ -90,6 +97,7 @@ const Name c( "c" ); const Name c_1( "c_1" ); const Name c_2( "c_2" ); const Name c_3( "c_3" ); +const Name c_reg( "c_reg" ); const Name capacity( "capacity" ); const Name center( "center" ); const Name circular( "circular" ); @@ -167,12 +175,18 @@ const Name elements( "elements" ); const Name elementsize( "elementsize" ); const Name ellipsoidal( "ellipsoidal" ); const Name elliptical( "elliptical" ); +const Name eprop_learning_window( "eprop_learning_window" ); +const Name eprop_reset_neurons_on_update( "eprop_reset_neurons_on_update" ); +const Name eprop_update_interval( "eprop_update_interval" ); const Name eps( "eps" ); +const Name epsilon( "epsilon" ); const Name equilibrate( "equilibrate" ); +const Name error_signal( "error_signal" ); const Name eta( "eta" ); const Name events( "events" ); const Name extent( "extent" ); +const Name f_target( "f_target" ); const Name file_extension( "file_extension" ); const Name filename( "filename" ); const Name filenames( "filenames" ); @@ -211,6 +225,7 @@ const Name g_ps( "g_ps" ); const Name g_rr( "g_rr" ); const Name g_sfa( "g_sfa" ); const Name g_sp( "g_sp" ); +const Name gamma( "gamma" ); const Name gamma_shape( "gamma_shape" ); const Name gaussian( "gaussian" ); const Name global_id( "global_id" ); @@ -271,6 +286,7 @@ const Name kernel( "kernel" ); const Name label( "label" ); const Name lambda( "lambda" ); const Name lambda_0( "lambda_0" ); +const Name learning_signal( "learning_signal" ); const Name len_kernel( "len_kernel" ); const Name linear( "linear" ); const Name linear_summation( "linear_summation" ); @@ -280,8 +296,10 @@ const Name local_spike_counter( "local_spike_counter" ); 
const Name lookuptable_0( "lookuptable_0" ); const Name lookuptable_1( "lookuptable_1" ); const Name lookuptable_2( "lookuptable_2" ); +const Name loss( "loss" ); const Name lower_left( "lower_left" ); +const Name m( "m" ); const Name major_axis( "major_axis" ); const Name make_symmetric( "make_symmetric" ); const Name mask( "mask" ); @@ -335,6 +353,7 @@ const Name off_grid_spiking( "off_grid_spiking" ); const Name offset( "offset" ); const Name offsets( "offsets" ); const Name omega( "omega" ); +const Name optimizer( "optimizer" ); const Name order( "order" ); const Name origin( "origin" ); const Name other( "other" ); @@ -394,6 +413,8 @@ const Name rate_times( "rate_times" ); const Name rate_values( "rate_values" ); const Name ratio_ER_cyt( "ratio_ER_cyt" ); const Name readout_cycle_duration( "readout_cycle_duration" ); +const Name readout_signal( "readout_signal" ); +const Name readout_signal_unnorm( "readout_signal_unnorm" ); const Name receptor_idx( "receptor_idx" ); const Name receptor_type( "receptor_type" ); const Name receptor_types( "receptor_types" ); @@ -409,6 +430,7 @@ const Name rectify_rate( "rectify_rate" ); const Name recv_buffer_size_secondary_events( "recv_buffer_size_secondary_events" ); const Name refractory_input( "refractory_input" ); const Name registered( "registered" ); +const Name regular_spike_arrival( "regular_spike_arrival" ); const Name relative_amplitude( "relative_amplitude" ); const Name requires_symmetric( "requires_symmetric" ); const Name reset_pattern( "reset_pattern" ); @@ -459,6 +481,8 @@ const Name stimulus_source( "stimulus_source" ); const Name stop( "stop" ); const Name structural_plasticity_synapses( "structural_plasticity_synapses" ); const Name structural_plasticity_update_interval( "structural_plasticity_update_interval" ); +const Name surrogate_gradient( "surrogate_gradient" ); +const Name surrogate_gradient_function( "surrogate_gradient_function" ); const Name synapse_id( "synapse_id" ); const Name synapse_label( "synapse_label" ); const Name synapse_model( "synapse_model" ); @@ -481,6 +505,7 @@ const Name t_ref_remaining( "t_ref_remaining" ); const Name t_ref_tot( "t_ref_tot" ); const Name t_spike( "t_spike" ); const Name target( "target" ); +const Name target_signal( "target_signal" ); const Name target_thread( "target_thread" ); const Name targets( "targets" ); const Name tau( "tau" ); @@ -508,6 +533,7 @@ const Name tau_max( "tau_max" ); const Name tau_minus( "tau_minus" ); const Name tau_minus_stdp( "tau_minus_stdp" ); const Name tau_minus_triplet( "tau_minus_triplet" ); +const Name tau_m_readout( "tau_m_readout" ); const Name tau_n( "tau_n" ); const Name tau_P( "tau_P" ); const Name tau_plus( "tau_plus" ); @@ -577,6 +603,7 @@ const Name time_update( "time_update" ); const Name times( "times" ); const Name to_do( "to_do" ); const Name total_num_virtual_procs( "total_num_virtual_procs" ); +const Name type( "type" ); const Name type_id( "type_id" ); const Name U( "U" ); @@ -591,6 +618,7 @@ const Name upper_right( "upper_right" ); const Name use_compressed_spikes( "use_compressed_spikes" ); const Name use_wfr( "use_wfr" ); +const Name v( "v" ); const Name V_act_NMDA( "V_act_NMDA" ); const Name V_clamp( "V_clamp" ); const Name V_epsp( "V_epsp" ); @@ -602,6 +630,7 @@ const Name V_reset( "V_reset" ); const Name V_T( "V_T" ); const Name V_T_star( "V_T_star" ); const Name V_th( "V_th" ); +const Name V_th_adapt( "V_th_adapt" ); const Name V_th_alpha_1( "V_th_alpha_1" ); const Name V_th_alpha_2( "V_th_alpha_2" ); const Name V_th_max( "V_th_max" 
); diff --git a/nestkernel/nest_names.h b/nestkernel/nest_names.h index 81b31a6056..b07bda0cb7 100644 --- a/nestkernel/nest_names.h +++ b/nestkernel/nest_names.h @@ -73,6 +73,9 @@ extern const Name a_thresh_th; extern const Name a_thresh_tl; extern const Name acceptable_latency; extern const Name activity; +extern const Name adapt_beta; +extern const Name adapt_tau; +extern const Name adaptation; extern const Name adapting_threshold; extern const Name adaptive_target_buffers; extern const Name add_compartments; @@ -97,10 +100,15 @@ extern const Name asc_decay; extern const Name asc_init; extern const Name asc_r; extern const Name available; +extern const Name average_gradient; extern const Name azimuth_angle; extern const Name b; +extern const Name batch_size; extern const Name beta; +extern const Name beta_1; +extern const Name beta_2; + extern const Name beta_Ca; extern const Name biological_time; extern const Name box; @@ -116,6 +124,7 @@ extern const Name c; extern const Name c_1; extern const Name c_2; extern const Name c_3; +extern const Name c_reg; extern const Name capacity; extern const Name center; extern const Name circular; @@ -193,12 +202,19 @@ extern const Name elements; extern const Name elementsize; extern const Name ellipsoidal; extern const Name elliptical; +extern const Name eprop_learning_window; +extern const Name eprop_reset_neurons_on_update; +extern const Name eprop_update_interval; extern const Name eps; +extern const Name epsilon; + extern const Name equilibrate; +extern const Name error_signal; extern const Name eta; extern const Name events; extern const Name extent; +extern const Name f_target; extern const Name file_extension; extern const Name filename; extern const Name filenames; @@ -237,6 +253,7 @@ extern const Name g_ps; extern const Name g_rr; extern const Name g_sfa; extern const Name g_sp; +extern const Name gamma; extern const Name gamma_shape; extern const Name gaussian; extern const Name global_id; @@ -297,6 +314,7 @@ extern const Name kernel; extern const Name label; extern const Name lambda; extern const Name lambda_0; +extern const Name learning_signal; extern const Name len_kernel; extern const Name linear; extern const Name linear_summation; @@ -306,8 +324,10 @@ extern const Name local_spike_counter; extern const Name lookuptable_0; extern const Name lookuptable_1; extern const Name lookuptable_2; +extern const Name loss; extern const Name lower_left; +extern const Name m; extern const Name major_axis; extern const Name make_symmetric; extern const Name mask; @@ -361,6 +381,7 @@ extern const Name off_grid_spiking; extern const Name offset; extern const Name offsets; extern const Name omega; +extern const Name optimizer; extern const Name order; extern const Name origin; extern const Name other; @@ -420,6 +441,8 @@ extern const Name rate_times; extern const Name rate_values; extern const Name ratio_ER_cyt; extern const Name readout_cycle_duration; +extern const Name readout_signal; +extern const Name readout_signal_unnorm; extern const Name receptor_idx; extern const Name receptor_type; extern const Name receptor_types; @@ -435,6 +458,7 @@ extern const Name rectify_rate; extern const Name recv_buffer_size_secondary_events; extern const Name refractory_input; extern const Name registered; +extern const Name regular_spike_arrival; extern const Name relative_amplitude; extern const Name requires_symmetric; extern const Name reset_pattern; @@ -485,6 +509,8 @@ extern const Name stimulus_source; extern const Name stop; extern const Name 
structural_plasticity_synapses; extern const Name structural_plasticity_update_interval; +extern const Name surrogate_gradient; +extern const Name surrogate_gradient_function; extern const Name synapse_id; extern const Name synapse_label; extern const Name synapse_model; @@ -507,6 +533,7 @@ extern const Name t_ref_remaining; extern const Name t_ref_tot; extern const Name t_spike; extern const Name target; +extern const Name target_signal; extern const Name target_thread; extern const Name targets; extern const Name tau; @@ -534,6 +561,7 @@ extern const Name tau_max; extern const Name tau_minus; extern const Name tau_minus_stdp; extern const Name tau_minus_triplet; +extern const Name tau_m_readout; extern const Name tau_n; extern const Name tau_P; extern const Name tau_plus; @@ -603,6 +631,7 @@ extern const Name time_update; extern const Name times; extern const Name to_do; extern const Name total_num_virtual_procs; +extern const Name type; extern const Name type_id; extern const Name U; @@ -617,6 +646,7 @@ extern const Name upper_right; extern const Name use_compressed_spikes; extern const Name use_wfr; +extern const Name v; extern const Name V_act_NMDA; extern const Name V_clamp; extern const Name V_epsp; @@ -628,6 +658,7 @@ extern const Name V_reset; extern const Name V_T; extern const Name V_T_star; extern const Name V_th; +extern const Name V_th_adapt; extern const Name V_th_alpha_1; extern const Name V_th_alpha_2; extern const Name V_th_max; diff --git a/nestkernel/node.cpp b/nestkernel/node.cpp index d490e40330..8d990c4a86 100644 --- a/nestkernel/node.cpp +++ b/nestkernel/node.cpp @@ -216,6 +216,30 @@ Node::register_stdp_connection( double, double ) throw IllegalConnection( "The target node does not support STDP synapses." ); } +void +Node::register_eprop_connection() +{ + throw IllegalConnection( "The target node does not support eprop synapses." ); +} + +long +Node::get_shift() const +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + +void +Node::write_update_to_history( const long t_previous_update, const long t_current_update ) +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + +bool +Node::is_eprop_recurrent_node() const +{ + throw IllegalConnection( "The target node is not an e-prop neuron." ); +} + /** * Default implementation of event handlers just throws * an UnexpectedEvent exception. @@ -398,12 +422,33 @@ Node::sends_secondary_event( DelayedRateConnectionEvent& ) throw IllegalConnection( "The source node does not support delayed rate output." ); } +void +Node::handle( LearningSignalConnectionEvent& ) +{ + throw UnexpectedEvent(); +} + void Node::handle( SICEvent& ) { throw UnexpectedEvent(); } +size_t +Node::handles_test_event( LearningSignalConnectionEvent&, size_t ) +{ + throw IllegalConnection( + "The target node cannot handle learning signal events or" + " synapse is not of type eprop_learning_signal_connection_bsshslm_2020." ); + return invalid_port; +} + +void +Node::sends_secondary_event( LearningSignalConnectionEvent& ) +{ + throw IllegalConnection(); +} + size_t Node::handles_test_event( SICEvent&, size_t ) { @@ -496,6 +541,12 @@ nest::Node::get_tau_syn_in( int ) throw UnexpectedEvent(); } +double +nest::Node::compute_gradient( std::vector< long >&, const long, const long, const double, const bool ) +{ + throw IllegalConnection( "The target node does not support compute_gradient()." 
); +} + void Node::event_hook( DSSpikeEvent& e ) { diff --git a/nestkernel/node.h b/nestkernel/node.h index 737ffe4f1f..83d3f33e32 100644 --- a/nestkernel/node.h +++ b/nestkernel/node.h @@ -410,6 +410,7 @@ class Node virtual size_t handles_test_event( InstantaneousRateConnectionEvent&, size_t receptor_type ); virtual size_t handles_test_event( DiffusionConnectionEvent&, size_t receptor_type ); virtual size_t handles_test_event( DelayedRateConnectionEvent&, size_t receptor_type ); + virtual size_t handles_test_event( LearningSignalConnectionEvent&, size_t receptor_type ); virtual size_t handles_test_event( SICEvent&, size_t receptor_type ); /** @@ -453,7 +454,17 @@ class Node virtual void sends_secondary_event( DelayedRateConnectionEvent& re ); /** - * Required to check, if source node may send a SICEvent. + * Required to check if source node may send a LearningSignalConnectionEvent. + * + * This base class implementation throws IllegalConnection + * and needs to be overwritten in the derived class. + * @ingroup event_interface + * @throws IllegalConnection + */ + virtual void sends_secondary_event( LearningSignalConnectionEvent& re ); + + /** + * Required to check if source node may send a SICEvent. * * This base class implementation throws IllegalConnection * and needs to be overwritten in the derived class. @@ -470,6 +481,45 @@ class Node */ virtual void register_stdp_connection( double, double ); + /** + * Initialize the update history and register the eprop synapse. + * + * @throws IllegalConnection + */ + virtual void register_eprop_connection(); + + /** + * Get the number of steps the time-point of the signal has to be shifted to + * place it at the correct location in the e-prop-related histories. + * + * @note Unlike the original e-prop, where signals arise instantaneously, NEST + * considers connection delays. Thus, to reproduce the original results, we + * compensate for the delays and synchronize the signals by shifting the + * history. + * + * @throws IllegalConnection + */ + virtual long get_shift() const; + + /** + * Register current update in the update history and deregister previous update. + * + * @throws IllegalConnection + */ + virtual void write_update_to_history( const long t_previous_update, const long t_current_update ); + + /** + * Return if the node is part of the recurrent network (and thus not a readout neuron). + * + * @note The e-prop synapse calls this function of the target node. If true, + * it skips weight updates within the first interval step of the update + * interval. + * + * @throws IllegalConnection + */ + virtual bool is_eprop_recurrent_node() const; + + /** * Handle incoming spike events. * @@ -588,8 +638,18 @@ class Node */ virtual void handle( DelayedRateConnectionEvent& e ); + /** + * Handler for learning signal connection events. + * + * @see handle(thread, LearningSignalConnectionEvent&) + * @ingroup event_interface + * @throws UnexpectedEvent + */ + virtual void handle( LearningSignalConnectionEvent& e ); + /** * Handler for slow inward current events (SICEvents). + * * @see handle(thread,SICEvent&) * @ingroup event_interface * @throws UnexpectedEvent @@ -741,6 +801,19 @@ class Node virtual double get_tau_syn_ex( int comp ); virtual double get_tau_syn_in( int comp ); + /** + * Compute gradient change for eprop synapses. + * + * This method is called from an eprop synapse on the eprop target neuron and returns the change in gradient. 
+ * + * @params presyn_isis is cleared during call + */ + virtual double compute_gradient( std::vector< long >& presyn_isis, + const long t_previous_update, + const long t_previous_trigger_spike, + const double kappa, + const bool average_gradient ); + /** * Modify Event object parameters during event delivery. * diff --git a/nestkernel/node_manager.cpp b/nestkernel/node_manager.cpp index d3155c701d..7c3ca760bc 100644 --- a/nestkernel/node_manager.cpp +++ b/nestkernel/node_manager.cpp @@ -737,6 +737,7 @@ NodeManager::check_wfr_use() InstantaneousRateConnectionEvent::set_coeff_length( kernel().connection_manager.get_min_delay() ); DelayedRateConnectionEvent::set_coeff_length( kernel().connection_manager.get_min_delay() ); DiffusionConnectionEvent::set_coeff_length( kernel().connection_manager.get_min_delay() ); + LearningSignalConnectionEvent::set_coeff_length( kernel().connection_manager.get_min_delay() ); SICEvent::set_coeff_length( kernel().connection_manager.get_min_delay() ); } diff --git a/nestkernel/proxynode.cpp b/nestkernel/proxynode.cpp index 596d4ad5c2..af4b624a57 100644 --- a/nestkernel/proxynode.cpp +++ b/nestkernel/proxynode.cpp @@ -73,6 +73,12 @@ proxynode::sends_secondary_event( DelayedRateConnectionEvent& re ) kernel().model_manager.get_node_model( get_model_id() )->sends_secondary_event( re ); } +void +proxynode::sends_secondary_event( LearningSignalConnectionEvent& re ) +{ + kernel().model_manager.get_node_model( get_model_id() )->sends_secondary_event( re ); +} + void proxynode::sends_secondary_event( SICEvent& sic ) { diff --git a/nestkernel/proxynode.h b/nestkernel/proxynode.h index 0726988b7f..4207e8763e 100644 --- a/nestkernel/proxynode.h +++ b/nestkernel/proxynode.h @@ -97,6 +97,8 @@ class proxynode : public Node void sends_secondary_event( DelayedRateConnectionEvent& ) override; + void sends_secondary_event( LearningSignalConnectionEvent& ) override; + void sends_secondary_event( SICEvent& ) override; void diff --git a/nestkernel/secondary_event.h b/nestkernel/secondary_event.h index 3d5f339fd3..00676e68f1 100644 --- a/nestkernel/secondary_event.h +++ b/nestkernel/secondary_event.h @@ -339,6 +339,7 @@ class DelayedRateConnectionEvent : public DataSecondaryEvent< double, DelayedRat DelayedRateConnectionEvent* clone() const override; }; + /** * Event for diffusion connections (rate model connections for the * siegert_neuron). The event transmits the rate to the connected neurons. @@ -375,6 +376,22 @@ class DiffusionConnectionEvent : public DataSecondaryEvent< double, DiffusionCon double get_diffusion_factor() const; }; +/** + * Event for learning signal connections. The event transmits + * the learning signal to the connected neurons. + */ +class LearningSignalConnectionEvent : public DataSecondaryEvent< double, LearningSignalConnectionEvent > +{ + +public: + LearningSignalConnectionEvent() + { + } + + void operator()() override; + LearningSignalConnectionEvent* clone() const override; +}; + template < typename DataType, typename Subclass > inline DataType DataSecondaryEvent< DataType, Subclass >::get_coeffvalue( std::vector< unsigned int >::iterator& pos ) @@ -426,6 +443,12 @@ DiffusionConnectionEvent::get_diffusion_factor() const return diffusion_factor_; } +inline LearningSignalConnectionEvent* +LearningSignalConnectionEvent::clone() const +{ + return new LearningSignalConnectionEvent( *this ); +} + /** * Event for slow inward current (SIC) connections between astrocytes and neurons. 
* diff --git a/nestkernel/simulation_manager.cpp b/nestkernel/simulation_manager.cpp index d072e27beb..2bcdf2802a 100644 --- a/nestkernel/simulation_manager.cpp +++ b/nestkernel/simulation_manager.cpp @@ -62,6 +62,9 @@ nest::SimulationManager::SimulationManager() , update_time_limit_( std::numeric_limits< double >::infinity() ) , min_update_time_( std::numeric_limits< double >::infinity() ) , max_update_time_( -std::numeric_limits< double >::infinity() ) + , eprop_update_interval_( 1000. ) + , eprop_learning_window_( 1000. ) + , eprop_reset_neurons_on_update_( true ) { } @@ -412,6 +415,39 @@ nest::SimulationManager::set_status( const DictionaryDatum& d ) update_time_limit_ = t_new; } + + // eprop update interval + double eprop_update_interval_new = 0.0; + if ( updateValue< double >( d, names::eprop_update_interval, eprop_update_interval_new ) ) + { + if ( eprop_update_interval_new <= 0 ) + { + LOG( M_ERROR, "SimulationManager::set_status", "eprop_update_interval > 0 required." ); + throw KernelException(); + } + + eprop_update_interval_ = eprop_update_interval_new; + } + + // eprop learning window + double eprop_learning_window_new = 0.0; + if ( updateValue< double >( d, names::eprop_learning_window, eprop_learning_window_new ) ) + { + if ( eprop_learning_window_new <= 0 ) + { + LOG( M_ERROR, "SimulationManager::set_status", "eprop_learning_window > 0 required." ); + throw KernelException(); + } + if ( eprop_learning_window_new > eprop_update_interval_ ) + { + LOG( M_ERROR, "SimulationManager::set_status", "eprop_learning_window <= eprop_update_interval required." ); + throw KernelException(); + } + + eprop_learning_window_ = eprop_learning_window_new; + } + + updateValue< bool >( d, names::eprop_reset_neurons_on_update, eprop_reset_neurons_on_update_ ); } void @@ -451,6 +487,9 @@ nest::SimulationManager::get_status( DictionaryDatum& d ) def< double >( d, names::time_deliver_spike_data, sw_deliver_spike_data_.elapsed() ); def< double >( d, names::time_deliver_secondary_data, sw_deliver_secondary_data_.elapsed() ); #endif + def< double >( d, names::eprop_update_interval, eprop_update_interval_ ); + def< double >( d, names::eprop_learning_window, eprop_learning_window_ ); + def< bool >( d, names::eprop_reset_neurons_on_update, eprop_reset_neurons_on_update_ ); } void @@ -824,43 +863,48 @@ nest::SimulationManager::update_() // and invalid markers have not been properly set in send buffers. if ( slice_ > 0 and from_step_ == 0 ) { - - if ( kernel().connection_manager.has_primary_connections() ) + // Deliver secondary events before primary events + // + // Delivering secondary events ahead of primary events ensures that LearningSignalConnectionEvents + // reach target neurons before spikes are propagated through eprop synapses. + // This sequence safeguards the gradient computation from missing critical information + // from the time step preceding the arrival of the spike triggering the weight update. + if ( kernel().connection_manager.secondary_connections_exist() ) { #ifdef TIMER_DETAILED if ( tid == 0 ) { - sw_deliver_spike_data_.start(); + sw_deliver_secondary_data_.start(); } #endif - // Deliver spikes from receive buffer to ring buffers. 
- kernel().event_delivery_manager.deliver_events( tid ); - + kernel().event_delivery_manager.deliver_secondary_events( tid, false ); #ifdef TIMER_DETAILED if ( tid == 0 ) { - sw_deliver_spike_data_.stop(); + sw_deliver_secondary_data_.stop(); } #endif } - if ( kernel().connection_manager.secondary_connections_exist() ) + + if ( kernel().connection_manager.has_primary_connections() ) { #ifdef TIMER_DETAILED if ( tid == 0 ) { - sw_deliver_secondary_data_.start(); + sw_deliver_spike_data_.start(); } #endif - kernel().event_delivery_manager.deliver_secondary_events( tid, false ); + // Deliver spikes from receive buffer to ring buffers. + kernel().event_delivery_manager.deliver_events( tid ); + #ifdef TIMER_DETAILED if ( tid == 0 ) { - sw_deliver_secondary_data_.stop(); + sw_deliver_spike_data_.stop(); } #endif } - #ifdef HAVE_MUSIC // advance the time of music by one step (min_delay * h) must // be done after deliver_events_() since it calls diff --git a/nestkernel/simulation_manager.h b/nestkernel/simulation_manager.h index b71a4767b8..83a37cac31 100644 --- a/nestkernel/simulation_manager.h +++ b/nestkernel/simulation_manager.h @@ -185,6 +185,10 @@ class SimulationManager : public ManagerInterface */ virtual void reset_timers_for_dynamics(); + Time get_eprop_update_interval() const; + Time get_eprop_learning_window() const; + bool get_eprop_reset_neurons_on_update() const; + private: void call_update_(); //!< actually run simulation, aka wrap update_ void update_(); //! actually perform simulation @@ -235,6 +239,10 @@ class SimulationManager : public ManagerInterface Stopwatch sw_deliver_spike_data_; Stopwatch sw_deliver_secondary_data_; #endif + + double eprop_update_interval_; + double eprop_learning_window_; + bool eprop_reset_neurons_on_update_; }; inline Time const& @@ -329,6 +337,25 @@ SimulationManager::get_wfr_interpolation_order() const { return wfr_interpolation_order_; } + +inline Time +SimulationManager::get_eprop_update_interval() const +{ + return Time::ms( eprop_update_interval_ ); +} + +inline Time +SimulationManager::get_eprop_learning_window() const +{ + return Time::ms( eprop_learning_window_ ); +} + +inline bool +SimulationManager::get_eprop_reset_neurons_on_update() const +{ + return eprop_reset_neurons_on_update_; +} + } diff --git a/pynest/examples/eprop_plasticity/README.rst b/pynest/examples/eprop_plasticity/README.rst new file mode 100644 index 0000000000..99c49c1e94 --- /dev/null +++ b/pynest/examples/eprop_plasticity/README.rst @@ -0,0 +1,26 @@ +E-prop plasticity examples +========================== + + +.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png + +Eligibility propagation (e-prop) [1]_ is a three-factor learning rule for spiking neural networks +that approximates backpropagation through time. The original TensorFlow implementation of e-prop +was demonstrated, among others, on a supervised regression task to generate temporal patterns and a +supervised classification task to accumulate evidence [2]_. Here, you find tutorials on how to +reproduce these two tasks as well as two more advanced regression tasks using the NEST implementation +of e-prop [3]_ and how to visualize the simulation recordings. + +References +---------- + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, + Maass W (2020). A solution to the learning dilemma for recurrent + networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. 
[2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/ + +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, + van Albada SJ, Bolten M, Diesmann M. Event-based implementation of + eligibility propagation (in preparation) diff --git a/pynest/examples/eprop_plasticity/chaos_handwriting.txt b/pynest/examples/eprop_plasticity/chaos_handwriting.txt new file mode 100644 index 0000000000..659b30882b --- /dev/null +++ b/pynest/examples/eprop_plasticity/chaos_handwriting.txt @@ -0,0 +1,228 @@ +0.648636 -0.77279 +1.476263 -1.28962 +1.291653 -2.41572 +-0.03011 -0.18367 +-0.26493 -0.26199 +-0.406622 -0.38268 +-0.289118 -0.24627 +-0.55635 -0.52727 +-0.885031 -0.71755 +-2.310027 -1.33729 +-3.308679 -1.83921 +-6.075538 -1.88952 +-2.167289 -0.0394 +-5.050104 0.76205 +-7.008399 1.48292 +-2.978087 1.09626 +-5.567521 2.10051 +-7.965183 4.16174 +-2.804973 2.41137 +-4.275307 5.4813 +-4.592527 9.13667 +-0.0834 0.96108 +-0.189203 1.94812 +0 2.89407 +0.467408 2.33688 +1.202457 4.28835 +3.229121 5.62072 +1.894848 1.24572 +4.265035 1.13402 +6.386488 0.7893 +3.508314 -0.57007 +6.137894 -2.39878 +9.161156 -4.23347 +2.830811 -1.71791 +5.596469 -3.54095 +8.395724 -5.30981 +1.793285 -1.13319 +3.617547 -2.2186 +5.381889 -3.39634 +1.187167 -0.79247 +2.360761 -1.61101 +3.468307 -2.5114 +1.559298 -1.26765 +3.080141 -2.59613 +3.707529 -4.59225 +0.10554 -0.33578 +0.04875 -0.82607 +-0.239204 -1.02848 +-0.621041 -0.43655 +-1.417551 -0.66724 +-2.176669 -0.66971 +-2.925403 -0.01 +-5.830075 -0.18178 +-7.22367 2.79841 +-0.599086 1.28113 +-0.341487 5.94904 +-0.287039 6.96014 +0.232207 4.31215 +0.418037 5.24874 +0.789348 9.1845 +0.0398 0.42188 +0.04311 0.84717 +0.09566 1.26766 +0.02077 0.16612 +-0.0552 -0.34422 +0 -0.50228 +0.124637 -0.35687 +0.286247 -0.71249 +0.526243 -1.00455 +1.483873 -1.80572 +2.431722 -2.46151 +4.305488 -3.82688 +0.722164 -0.52622 +1.43401 -1.07167 +2.200606 -1.53075 +0.688289 -0.41219 +1.351015 -1.1057 +2.152753 -1.07632 +2.137701 0.0783 +2.534343 1.41198 +3.276954 2.84626 +0.674608 1.30294 +1.158806 2.68923 +2.391939 3.63553 +0.438095 0.33619 +1.027629 0.53802 +1.578692 0.50227 +2.012747 -0.13055 +3.665871 -1.21889 +5.381874 -2.10478 +0.868188 -0.44819 +1.722043 -0.9258 +2.607228 -1.33941 +1.50658 -0.70397 +3.911214 -1.59131 +5.429708 -2.15261 +0.707305 -0.26146 +1.40172 -0.56553 +2.128834 -0.76538 +1.782091 -0.48984 +3.103039 -0.70296 +4.879584 -0.81321 +0.501406 -0.0311 +1.006692 -0.0702 +1.506924 -0.0239 +0.367773 0.034 +0.733835 0.12488 +1.076351 0.26308 +0.200514 0.0809 +0.369618 0.40622 +0.334872 0.43054 +-0.07864 0.0551 +-0.197313 0.0581 +-0.287004 0.0239 +-0.526534 -0.20091 +-1.008692 -0.50555 +-1.53086 -0.71754 +-1.105141 -0.44868 +-2.448894 -0.9588 +-3.683593 -0.64579 +-0.533778 0.13531 +-1.033013 0.49229 +-1.363424 0.93281 +-1.412492 1.8832 +-1.747298 3.43413 +-2.080967 5.66856 +-0.124092 0.83098 +-0.443021 1.70263 +-0.215307 2.51138 +0.08223 0.29205 +0.671433 0.14095 +0.90895 -0.0478 +0.45008 -0.35773 +0.642403 -0.95372 +0.956783 -1.43508 +0.45107 -0.69066 +0.888341 -1.39026 +1.33949 -2.08086 +0.194707 -0.29806 +0.399095 -0.58967 +0.597975 -0.88496 +0.06424 -0.0954 +0.07651 -0.28155 +0.19137 -0.28702 +0.5256 -0.025 +1.797592 0.7486 +2.009235 0.98064 +0.610976 0.66988 +1.069277 1.46791 +1.530858 2.24828 +0.429774 0.72659 +0.666171 1.9912 +1.554761 2.46357 +1.028877 0.54695 +2.518381 -0.0458 +3.348731 -0.50229 +1.8139 -0.9973 +3.49507 -2.21975 +5.21442 -3.37243 +2.63722 -1.76803 +3.19628 -2.73439 +6.243 -3.46811 +0.49706 -0.11971 
+1.02917 -0.003 +1.53083 0.0957 +0.47996 0.0944 +0.92178 0.32821 +1.38732 0.47838 +0.0304 0.01 +0.11826 0.0226 +0.0957 0 +-0.19726 -0.19726 +-1.6926 -1.44786 +-1.77006 -1.50684 +-0.40998 -0.31212 +-1.49627 -1.23755 +-2.20057 -0.90889 +-0.73769 0.34423 +-1.10053 2.67884 +-1.12422 2.98975 +-0.0947 1.24238 +-0.42131 5.26097 +2.00923 5.23804 +1.5887 -0.015 +3.04225 -3.04059 +3.77927 -3.9943 +0.0617 -0.0798 +0.20031 -0.0442 +0.28703 -0.0957 +0.17139 -0.10176 +0.35215 -0.20451 +0.47838 -0.35877 +0.60054 -0.73395 +0.6348 -1.91165 +1.60263 -2.36788 +0.27414 -0.12924 +0.61629 -0.10284 +0.90891 -0.0239 +1.64322 0.44317 +3.20826 1.15116 +4.85565 1.57859 +1.09496 0.28409 +2.21874 0.49801 +3.34876 0.5501 +1.13378 0.0523 +2.27804 -0.0465 +3.39654 -0.23917 +1.6817 -0.28973 +4.27095 -1.65147 +6.19515 -0.83712 +1.22228 0.51728 +2.26271 1.41244 +3.25301 2.29611 +0.80207 0.7157 +1.4609 1.58551 +2.081 2.46357 +0.83897 1.18794 +1.83665 2.6557 +1.00462 4.20955 +-0.55569 1.03777 +-2.74372 1.38641 +-3.61183 1.43509 +-0.61312 0.0344 +-3.79916 0.10062 +-4.59254 -0.81321 +-0.65597 -0.75555 +-0.36283 -1.23904 +-0.0957 -1.81778 diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py new file mode 100644 index 0000000000..f0facff028 --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py @@ -0,0 +1,789 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_classification_evidence-accumulation.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning to accumulate evidence with e-prop +------------------------------------------------------- + +Training a classification model using supervised e-prop plasticity to accumulate evidence. + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a classification task with the eligibility propagation (e-prop) +plasticity mechanism by Bellec et al. [1]_. + +This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their +TensorFlow script given in [2]_. + +The task, a so-called evidence accumulation task, is inspired by behavioral tasks, where a lab animal (e.g., a +mouse) runs along a track, gets cues on the left and right, and has to decide at the end of the track between +taking a left and a right turn of which one is correct. After a number of iterations, the animal is able to +infer the underlying rationale of the task. Here, the solution is to turn to the side in which more cues were +presented. + +.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png + :width: 70 % + :alt: See Figure 1 below. + :align: center + +Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. 
+This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model +consists of a recurrent network that receives input from Poisson generators and projects onto two readout +neurons - one for the left and one for the right turn at the end. The input neuron population consists of four +groups: one group providing background noise of a specific rate for some base activity throughout the +experiment, one group providing the input spikes of the left cues and one group providing them for the right +cues, and a last group defining the recall window, in which the network has to decide. The readout neuron +compares the network signal :math:`\pi_k` with the teacher target signal :math:`\pi_k^*`, which it receives from +a rate generator. Since the decision is at the end and all the cues are relevant, the network has to keep the +cues in memory. Additional adaptive neurons in the network enable this memory. The network's training error is +assessed by employing a cross-entropy error loss. + +Details on the event-based NEST implementation of e-prop can be found in [3]_. + +References +~~~~~~~~~~ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the + learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py + +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. + Event-based implementation of eligibility propagation (in preparation) +""" # pylint: disable=line-too-long # noqa: E501 + +# %% ########################################################################################################### +# Import libraries +# ~~~~~~~~~~~~~~~~ +# We begin by importing all libraries required for the simulation, analysis, and visualization. + +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and +# synapse models below. The connections that must be established are numbered 1 to 7. + +try: + Image(filename="./eprop_supervised_classification_schematic_evidence-accumulation.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... 
+# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. +# Using a batch size larger than one aids the network in generalization, facilitating the solution to this task. +# The original number of iterations requires distributed computing. + +n_batch = 1 # batch size, 64 in reference [2], 32 in the README to reference [2] +n_iter = 5 # number of iterations, 2000 in reference [2], 50 with n_batch 32 converges + +n_input_symbols = 4 # number of input populations, e.g. 4 = left, right, recall, noise +n_cues = 7 # number of cues given before decision +prob_group = 0.3 # probability with which one input group is present + +steps = { + "cue": 100, # time steps in one cue presentation + "spacing": 50, # time steps of break between two cues + "bg_noise": 1050, # time steps of background noise + "recall": 150, # time steps of recall +} + +steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues +steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence +steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. + +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. 
Within the recurrent network, alongside a population of regular neurons, we introduce a +# population of adaptive neurons, to enhance the network's memory retention. + +n_in = 40 # number of input neurons +n_ad = 50 # number of adaptive neurons +n_reg = 50 # number of regular neurons +n_rec = n_ad + n_reg # number of recurrent neurons +n_out = 2 # number of readout neurons + + +params_nrn_reg = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "c_reg": 2.0, # firing rate regularization scaling - double the TF c_reg for technical reasons + "E_L": 0.0, # mV, leak / resting membrane potential + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # scaling of the pseudo derivative + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": True, # If True, input spikes arrive at end of time step, if False at beginning + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 5.0, # ms, duration of refractory period + "tau_m": 20.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage + "V_th": 0.6, # mV, spike threshold membrane voltage +} + +params_nrn_ad = { + "adapt_tau": 2000.0, # ms, time constant of adaptive threshold + "adaptation": 0.0, # initial value of the spike threshold adaptation + "C_m": 1.0, + "c_reg": 2.0, + "E_L": 0.0, + "f_target": 10.0, + "gamma": 0.3, + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 5.0, + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, +} + +params_nrn_ad["adapt_beta"] = 1.7 * ( + (1.0 - np.exp(-duration["step"] / params_nrn_ad["adapt_tau"])) + / (1.0 - np.exp(-duration["step"] / params_nrn_ad["tau_m"])) +) # prefactor of adaptive threshold + +params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "cross_entropy", # loss function + "regular_spike_arrival": False, + "tau_m": 20.0, + "V_m": 0.0, +} + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. + +nrns_reg = nest.Create("eprop_iaf_bsshslm_2020", n_reg, params_nrn_reg) +nrns_ad = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_ad, params_nrn_ad) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + +nrns_rec = nrns_reg + nrns_ad + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. 
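+#
+# As an optional, minimal illustration (not used further in this script), recordings could be thinned out and
+# written to file by passing parameters such as the following to a recorder; the interval, recorded variable,
+# and label chosen here are placeholders.
+
+params_mm_to_file_example = {
+    "interval": 10.0 * duration["step"],  # record only every 10th time step to reduce memory consumption
+    "record_from": ["V_m"],  # restrict the recording to the membrane voltage
+    "record_to": "ascii",  # write the recordings to file instead of keeping them in memory
+    "label": "eprop_V_m_recording",  # file name prefix of the written recordings
+}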
+ +n_record = 1 # number of neurons per type to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_reg = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording +} + +params_mm_ad = { + "interval": duration["step"], + "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_sr = { + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +#################### + +mm_reg = nest.Create("multimeter", params_mm_reg) +mm_ad = nest.Create("multimeter", params_mm_ad) +mm_out = nest.Create("multimeter", params_mm_out) +sr = nest.Create("spike_recorder", params_sr) +wr = nest.Create("weight_recorder", params_wr) + +nrns_reg_record = nrns_reg[:n_record] +nrns_ad_record = nrns_ad[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. 
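+#
+# As a small self-contained sketch (the names below are for illustration only and are not reused): the weight
+# matrices passed to nest.Connect further down are arranged with one row per target neuron and one column per
+# source neuron, which is why the drawn matrices are transposed. A separate random generator is used here so
+# that the script's seeded global random state stays untouched.
+
+example_rng = np.random.default_rng(42)
+example_n_pre, example_n_post = 3, 2
+example_weights = example_rng.normal(size=(example_n_pre, example_n_post)).T / np.sqrt(example_n_pre)
+assert example_weights.shape == (example_n_post, example_n_pre)  # one row per target, one column per source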
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + + +def calculate_glorot_dist(fan_in, fan_out): + glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0) + glorot_limit = np.sqrt(3.0 * glorot_scale) + glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out)) + return glorot_distribution + + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "adam", # algorithm to optimize the weights + "batch_size": n_batch, + "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer + "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer + "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer + "eta": 5e-3, # learning rate + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": True, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_out_out = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, # receptor type of readout neuron to receive other readout neuron's signals for softmax + "weight": 1.0, # pA, weight 1.0 required for correct softmax computation for technical reasons +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + +#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # 
connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 +nest.Connect(nrns_out, nrns_out, params_conn_all_to_all, params_syn_out_out) # connection 7 + +nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_reg, nrns_reg_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_ad, nrns_ad_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single synapses (here exemplarily for two connections). + +nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + +# %% ########################################################################################################### +# Create input and output +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We generate the input as four neuron populations, two producing the left and right cues, respectively, one the +# recall signal and one the background input throughout the task. The sequence of cues is drawn with a +# probability that favors one side. For each such sequence, the favored side, the solution or target, is +# assigned randomly to the left or right. + + +def generate_evidence_accumulation_input_output( + n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps +): + n_pop_nrn = n_in // n_input_symbols + + prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) + idx = np.random.choice([0, 1], n_batch) + probs = np.zeros((n_batch, 2), dtype=np.float32) + probs[:, 0] = prob_choices[idx] + probs[:, 1] = prob_choices[1 - idx] + + batched_cues = np.zeros((n_batch, n_cues), dtype=int) + for b_idx in range(n_batch): + batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + + input_spike_probs = np.zeros((n_batch, steps["sequence"], n_in)) + + for b_idx in range(n_batch): + for c_idx in range(n_cues): + cue = batched_cues[b_idx, c_idx] + + step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] + step_stop = step_start + steps["cue"] + + pop_nrn_start = cue * n_pop_nrn + pop_nrn_stop = pop_nrn_start + n_pop_nrn + + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob + input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) + input_spike_bools[:, 0, :] = 0 # remove spikes in 0th time step of every sequence for technical reasons + + target_cues = np.zeros(n_batch, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + + return input_spike_bools, target_cues + + +input_spike_prob = 0.04 # spike probability of frozen input noise +dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + +input_spike_bools_list = [] +target_cues_list = [] + +for iteration in range(n_iter): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output( + n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + ) + input_spike_bools_list.append(input_spike_bools) + target_cues_list.extend(target_cues.tolist()) + +input_spike_bools_arr = 
np.array(input_spike_bools_list).reshape(steps["task"], n_in) +timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] + +params_gen_spk_in = [ + {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} + for nrn_in_idx in range(n_in) +] + +target_rate_changes = np.zeros((n_out, n_batch * n_iter)) +target_rate_changes[np.array(target_cues_list), np.arange(n_batch * n_iter)] = 1 + +params_gen_rate_target = [ + { + "amplitude_times": np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"], + "amplitude_values": target_rate_changes[nrn_out_idx], + } + for nrn_out_idx in range(n_out) +] + + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. + + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. 
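+# For reference, each entry of weights_pre_train (and of weights_post_train below) is the dictionary
+# returned by get_weights(): flat "source", "target", and "weight" arrays plus a dense "weight_matrix" of
+# shape (n_post, n_pre); for example, the initial input-to-recurrent matrix is available as
+#
+#   weights_pre_train["in_rec"]["weight_matrix"]  # shape (n_rec, n_in)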
+ +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_reg = mm_reg.get("events") +events_mm_ad = mm_ad.get("events") +events_mm_out = mm_out.get("events") +events_sr = sr.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the cross-entropy error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] # corresponds to softmax +target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) +target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + +readout_signal = readout_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) +readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] + +target_signal = target_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) +target_signal = target_signal[:, :, :, -steps["learning_window"] :] + +loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) + +y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) +y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) +accuracy = np.mean((y_target == y_prediction), axis=1) +recall_errors = 1.0 - accuracy + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. + +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "font.sans-serif": "Arial", + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with two plots visualizing the training error of the network: the loss and the recall error, both +# plotted against the iterations. + +fig, axs = plt.subplots(2, 1, sharex=True) + +axs[0].plot(range(1, n_iter + 1), loss) +axs[0].set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$") + +axs[1].plot(range(1, n_iter + 1), recall_errors) +axs[1].set_ylabel("recall errors") + +axs[-1].set_xlabel("training iteration") +axs[-1].set_xlim(1, n_iter) +axs[-1].xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. 
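+# In the panel labels used below, z denotes spikes, v the membrane voltage V_m, psi the surrogate gradient,
+# L the learning signal, A the adaptive threshold V_th_adapt, and pi and pi^* the readout (softmax) and
+# target signals, with indices i, j, k for input, recurrent, and readout neurons, respectively.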
+ + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, nrns, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) + senders_subset = events["senders"][idc_times][idc_sender] + times_subset = events["times"][idc_times][idc_sender] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: + fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) + + plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr, nrns_reg, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_reg, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_reg, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_reg, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_spikes(axs[5], events_sr, nrns_ad, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[6], events_mm_ad, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_ad, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[8], events_mm_ad, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[9], events_mm_ad, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[10], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[11], events_mm_out, "target_signal", r"$\pi^*_k$" + "\n", xlims) + plot_recordable(axs[12], events_mm_out, "readout_signal", r"$\pi_k$" + "\n", xlims) + plot_recordable(axs[13], events_mm_out, "error_signal", r"$\pi_k-\pi^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. 
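+# For reference, the weight recorder events used below are flat arrays under the keys "senders", "targets",
+# "times", and "weights", with one entry per recorded weight update (that is, per transmitted spike of a
+# recorded synapse).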
+ + +def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) + +plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course( + axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" +) +plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. + +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png b/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png new file mode 100644 index 0000000000..60738a57b2 Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_classification_schematic_evidence-accumulation.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py 
b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py new file mode 100644 index 0000000000..6c2f5aa46c --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting.py @@ -0,0 +1,719 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_regression_handwriting.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning to generate handwritten text with e-prop +------------------------------------------------------------- + +Training a regression model using supervised e-prop plasticity to generate handwritten text + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a regression task with a recurrent spiking neural network that +is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_. + +This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their +TensorFlow script given in [2]_ and changed the task as well as the parameters slightly. + + +In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the network +learns to reproduce with its overall spiking activity a two-dimensional, roughly one-second-long target signal +which encode the x and y coordinates of the handwritten word "chaos". + +.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png + :width: 70 % + :alt: See Figure 1 below. + :align: center + +Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. +This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model +consists of a recurrent network that receives frozen noise input from Poisson generators and projects onto two +readout neurons. Each individual readout signal denoted as :math:`y_k` is compared with a corresponding target +signal represented as :math:`y_k^*`. The network's training error is assessed by employing a mean-squared error +loss. + +Details on the event-based NEST implementation of e-prop can be found in [3]_. + +The development of this task and the hyper-parameter optimization were conducted by Agnes Korcsak-Gorzo and +Charl Linssen, inspired by activities and feedback received at the CapoCaccia Workshop toward Neuromorphic +Intelligence 2023. + +References +~~~~~~~~~~ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the + learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py + +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. 
+ Event-based implementation of eligibility propagation (in preparation) +""" # pylint: disable=line-too-long # noqa: E501 + +# %% ########################################################################################################### +# Import libraries +# ~~~~~~~~~~~~~~~~ +# We begin by importing all libraries required for the simulation, analysis, and visualization. + +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and +# synapse models below. The connections that must be established are numbered 1 to 6. + +try: + Image(filename="./eprop_supervised_regression_schematic_handwriting.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. + +n_batch = 1 # batch size +n_iter = 5 # number of iterations, 5000 for good convergence + +data_file_name = "chaos_handwriting.txt" # name of file with task data +data = np.loadtxt(data_file_name) + +steps = { + "data_point": 8, # time steps of one data point +} + +steps["sequence"] = len(data) * steps["data_point"] # time steps of one full sequence +steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. 
+# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. + +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing + "rng_seed": rng_seed, # seed for NEST random generator +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. + +n_in = 100 # number of input neurons +n_rec = 200 # number of recurrent neurons +n_out = 2 # number of readout neurons + +tau_m_mean = 30.0 # ms, mean of membrane time constant distribution + +params_nrn_rec = { + "adapt_tau": 2000.0, # ms, time constant of adaptive threshold + "C_m": 250.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "c_reg": 150.0, # firing rate regularization scaling + "E_L": 0.0, # mV, leak / resting membrane potential + "f_target": 20.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # scaling of the pseudo derivative + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period + "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage + "V_th": 0.03, # mV, spike threshold membrane voltage +} + +params_nrn_rec["adapt_beta"] = ( + 1.7 * (1.0 - np.exp(-1 / params_nrn_rec["adapt_tau"])) / (1.0 - np.exp(-1.0 / tau_m_mean)) +) # prefactor of adaptive threshold + +params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, + "tau_m": 50.0, + "V_m": 0.0, +} + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. 
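+# For reference, the adapt_beta prefactor defined above implements
+#
+#   adapt_beta = 1.7 * (1 - exp(-dt / adapt_tau)) / (1 - exp(-dt / tau_m)),  with dt = 1 ms (the resolution),
+#
+# that is, the base value 1.7 is rescaled by the ratio of the per-step update factors of the threshold
+# adaptation and the membrane; this is our reading of the expression and keeps the effective adaptation
+# strength comparable for different choices of the two time constants.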
+ +nrns_rec = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. + +n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_rec = { + "interval": duration["step"], # interval between two recorded time points + "record_from": [ + "V_m", + "surrogate_gradient", + "learning_signal", + "V_th_adapt", + "adaptation", + ], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_sr = { + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +#################### + +mm_rec = nest.Create("multimeter", params_mm_rec) +mm_out = nest.Create("multimeter", params_mm_out) +sr = nest.Create("spike_recorder", params_sr) +wr = nest.Create("weight_recorder", params_wr) + +nrns_rec_record = nrns_rec[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. 
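+# The plastic e-prop synapses configured below use an Adam optimizer to turn the gradient g accumulated for
+# each weight into a weight change. As a reminder, a single Adam step for one weight w reads (a generic
+# textbook sketch with step counter t, not the NEST-internal code; eta, beta_1, beta_2, and epsilon are the
+# parameters set below):
+#
+#   m = beta_1 * m + (1 - beta_1) * g
+#   v = beta_2 * v + (1 - beta_2) * g**2
+#   w = w - eta * (m / (1 - beta_1**t)) / (np.sqrt(v / (1 - beta_2**t)) + epsilon)
+#
+# and the updated weight is additionally kept within the limits Wmin and Wmax.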
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "adam", # algorithm to optimize the weights + "batch_size": n_batch, + "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer + "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer + "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer + "eta": 5e-3, # learning rate + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": False, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + +#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 + +nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single 
synapses (here exemplarily for two connections). + +nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2) + +# %% ########################################################################################################### +# Create input +# ~~~~~~~~~~~~ +# We generate some frozen Poisson spike noise of a fixed rate that is repeated in each iteration and feed these +# spike times to the previously created input spike generator. The network will use these spike times as a +# temporal backbone for encoding the target signal into its recurrent spiking activity. + +input_spike_prob = 0.05 # spike probability of frozen input noise +dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + +input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) +input_spike_bools[:, 0] = 0 # remove spikes in 0th time step of every sequence for technical reasons + +sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] +params_gen_spk_in = [] +for input_spike_bool in input_spike_bools: + input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool] + input_spike_times_all = [input_spike_times + start for start in sequence_starts] + params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) + +# %% ########################################################################################################### +# Create output +# ~~~~~~~~~~~~~ +# Then, we load the x and y values of an image of the word "chaos" written by hand and construct a roughly +# two-second long target signal from it. This signal, like the input, is repeated for all iterations and fed +# into the rate generator that was previously created. + +x_eval = np.arange(steps["sequence"]) / steps["data_point"] +x_data = np.arange(steps["sequence"] // steps["data_point"]) + +target_signal_list = [] +for y_data in np.cumsum(data, axis=0).T: + y_data /= np.max(np.abs(y_data)) + y_data -= np.mean(y_data) + target_signal_list.append(np.interp(x_eval, x_data, y_data)) + +params_gen_rate_target = [] + +for target_signal in target_signal_list: + params_gen_rate_target.append( + { + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_values": np.tile(target_signal, n_iter * n_batch), + } + ) + +#################### + +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. 
+ +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. + + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_rec = mm_rec.get("events") +events_mm_out = mm_out.get("events") +events_sr = sr.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] +target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +loss_list = [] +for sender in set(senders): + idc = senders == sender + error = (readout_signal[idc] - target_signal[idc]) ** 2 + loss_list.append(0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"]))) + +loss = np.sum(loss_list, axis=0) + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. 
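+# For reference, the loss computed above is the squared-error loss
+#
+#   E = 1 / 2 * sum_{t,k} (y_k^t - y_k^{*,t})^2,
+#
+# evaluated separately for each training iteration by summing the error of both readout neurons over one
+# sequence; this is also the quantity plotted in the training error panel below.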
+ +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "font.sans-serif": "Arial", + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot pattern +# ............ +# First, we visualize the created pattern and plot the target for comparison. The outputs of the two readout +# neurons encode the horizontal and vertical coordinate of the pattern respectively. + +fig, ax = plt.subplots() + +ax.plot( + readout_signal[senders == list(set(senders))[0]][-steps["sequence"] :], + -readout_signal[senders == list(set(senders))[1]][-steps["sequence"] :], + c=colors["red"], + label="readout", +) + +ax.plot( + target_signal[senders == list(set(senders))[0]][-steps["sequence"] :], + -target_signal[senders == list(set(senders))[1]][-steps["sequence"] :], + c=colors["blue"], + label="target", +) + +ax.set_xlabel(r"$y_0$ and $y^*_0$") +ax.set_ylabel(r"$y_1$ and $y^*_1$") + +ax.axis("equal") + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. + +fig, ax = plt.subplots() + +ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--") +ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted") +ax.plot(range(1, n_iter + 1), loss, label=r"$E$", c=colors["blue"]) +ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") +ax.set_xlabel("training iteration") +ax.set_xlim(1, n_iter) +ax.xaxis.get_major_locator().set_params(integer=True) +ax.legend(bbox_to_anchor=(1.01, 0.5), loc="center left") +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. 
+
+
+def plot_recordable(ax, events, recordable, ylabel, xlims):
+    for sender in set(events["senders"]):
+        idc_sender = events["senders"] == sender
+        idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1])
+        ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5)
+    ax.set_ylabel(ylabel)
+    margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1
+    ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin)
+
+
+def plot_spikes(ax, events, nrns, ylabel, xlims):
+    idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1])
+    idc_sender = np.isin(events["senders"][idc_times], nrns.tolist())
+    senders_subset = events["senders"][idc_times][idc_sender]
+    times_subset = events["times"][idc_times][idc_sender]
+
+    ax.scatter(times_subset, senders_subset, s=0.1)
+    ax.set_ylabel(ylabel)
+    margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1
+    ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)
+
+
+for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
+    fig, axs = plt.subplots(10, 1, sharex=True, figsize=(8, 10), gridspec_kw={"hspace": 0.4, "left": 0.2})
+
+    plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims)
+    plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims)
+
+    plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims)
+    plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims)
+    plot_recordable(axs[4], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims)
+    plot_recordable(axs[5], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims)
+
+    plot_recordable(axs[6], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims)
+    plot_recordable(axs[7], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims)
+    plot_recordable(axs[8], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims)
+    plot_recordable(axs[9], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims)
+
+    axs[-1].set_xlabel(r"$t$ (ms)")
+    axs[-1].set_xlim(*xlims)
+
+    fig.align_ylabels()
+
+# %% ###########################################################################################################
+# Plot weight time courses
+# ........................
+# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works
+# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight
+# recorder only records the weight in those moments. That is why the first weight registrations do not start in
+# the first time step and we add the initial weights manually.
+ + +def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) + +plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course( + axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" +) +plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. + +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py new file mode 100644 index 0000000000..1ed14fd4a6 --- /dev/null +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_infinite-loop.py @@ -0,0 +1,713 @@ +# -*- coding: utf-8 -*- +# +# eprop_supervised_regression_infinite-loop.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +r""" +Tutorial on learning to generate an infinite loop with e-prop +------------------------------------------------------------- + +Training a regression model using supervised e-prop plasticity to generate an infinite loop + +Description +~~~~~~~~~~~ + +This script demonstrates supervised learning of a regression task with a recurrent spiking neural network that +is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_. + +This type of learning is demonstrated at the proof-of-concept task in [1]_. We based this script on their +TensorFlow script given in [2]_ and changed the task as well as the parameters slightly. + + +In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the network +learns to reproduce with its overall spiking activity a two-dimensional, roughly two-second-long target signal +which encode the x and y coordinates of an infinite-loop. + +.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png + :width: 70 % + :alt: See Figure 1 below. + :align: center + +Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity. +This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model +consists of a recurrent network that receives frozen noise input from Poisson generators and projects onto two +readout neurons. Each individual readout signal denoted as :math:`y_k` is compared with a corresponding target +signal represented as :math:`y_k^*`. The network's training error is assessed by employing a mean-squared error +loss. + +Details on the event-based NEST implementation of e-prop can be found in [3]_. + +The development of this task and the hyper-parameter optimization were conducted by Agnes Korcsak-Gorzo and +Charl Linssen, inspired by activities and feedback received at the CapoCaccia Workshop toward Neuromorphic +Intelligence 2023. + +References +~~~~~~~~~~ + +.. [1] Bellec G, Scherr F, Subramoney F, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the + learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625. + https://doi.org/10.1038/s41467-020-17236-y + +.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py + +.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M. + Event-based implementation of eligibility propagation (in preparation) +""" # pylint: disable=line-too-long # noqa: E501 + +# %% ########################################################################################################### +# Import libraries +# ~~~~~~~~~~~~~~~~ +# We begin by importing all libraries required for the simulation, analysis, and visualization. 
+ +import matplotlib as mpl +import matplotlib.pyplot as plt +import nest +import numpy as np +from cycler import cycler +from IPython.display import Image + +# %% ########################################################################################################### +# Schematic of network architecture +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# This figure, identical to the one in the description, shows the required network architecture in the center, +# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and +# synapse models below. The connections that must be established are numbered 1 to 6. + +try: + Image(filename="./eprop_supervised_regression_schematic_infinite-loop.png") +except Exception: + pass + +# %% ########################################################################################################### +# Setup +# ~~~~~ + +# %% ########################################################################################################### +# Initialize random generator +# ........................... +# We seed the numpy random generator, which will generate random initial weights as well as random input and +# output. + +rng_seed = 1 # numpy random seed +np.random.seed(rng_seed) # fix numpy random seed + +# %% ########################################################################################################### +# Define timing of task +# ..................... +# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. + +n_batch = 1 # batch size +n_iter = 5 # number of iterations, 5000 for good convergence + +steps = { + "sequence": 1258, # time steps of one full sequence +} + +steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. 
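+# In this task, the synaptic weights are updated once per sequence and the learning signal is non-zero
+# throughout the sequence, that is, eprop_update_interval and eprop_learning_window below are both equal to
+# duration["sequence"] (1258 ms at a resolution of 1 ms).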
+ +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing + "rng_seed": rng_seed, # seed for NEST random generator +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. + +n_in = 100 # number of input neurons +n_rec = 200 # number of recurrent neurons +n_out = 2 # number of readout neurons + +tau_m_mean = 30.0 # ms, mean of membrane time constant distribution + +params_nrn_rec = { + "adapt_tau": 2000.0, # ms, time constant of adaptive threshold + "C_m": 250.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "c_reg": 150.0, # firing rate regularization scaling + "E_L": 0.0, # mV, leak / resting membrane potential + "f_target": 20.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # scaling of the pseudo derivative + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period + "tau_m": nest.random.normal(mean=tau_m_mean, std=2.0), # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage + "V_th": 0.03, # mV, spike threshold membrane voltage +} + +params_nrn_rec["adapt_beta"] = ( + 1.7 * (1.0 - np.exp(-1 / params_nrn_rec["adapt_tau"])) / (1.0 - np.exp(-1.0 / tau_m_mean)) +) # prefactor of adaptive threshold + +params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, + "tau_m": 50.0, + "V_m": 0.0, +} + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. 
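+# If desired, the full set of default parameters of the e-prop models created next can be inspected from
+# Python, for example via
+#
+#   nest.GetDefaults("eprop_iaf_adapt_bsshslm_2020")  # returns a dictionary of parameter defaults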
+ +nrns_rec = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. + +n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_rec = { + "interval": duration["step"], # interval between two recorded time points + "record_from": [ + "V_m", + "surrogate_gradient", + "learning_signal", + "V_th_adapt", + "adaptation", + ], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_sr = { + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +#################### + +mm_rec = nest.Create("multimeter", params_mm_rec) +mm_out = nest.Create("multimeter", params_mm_out) +sr = nest.Create("spike_recorder", params_sr) +wr = nest.Create("weight_recorder", params_wr) + +nrns_rec_record = nrns_rec[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. 
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "adam", # algorithm to optimize the weights + "batch_size": n_batch, + "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer + "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer + "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer + "eta": 5e-3, # learning rate + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": False, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +params_init_optimizer = { + "optimizer": { + "m": 0.0, # initial 1st moment estimate m of Adam optimizer + "v": 0.0, # initial 2nd moment raw estimate v of Adam optimizer + } +} + +#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 + +nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# After creating the connections, we can individually initialize the optimizer's +# dynamic variables for single 
synapses (here exemplarily for two connections).
+
+nest.GetConnections(nrns_rec[0], nrns_rec[1:3]).set([params_init_optimizer] * 2)
+
+# %% ###########################################################################################################
+# Create input
+# ~~~~~~~~~~~~
+# We generate some frozen Poisson spike noise of a fixed rate that is repeated in each iteration and feed these
+# spike times to the previously created input spike generator. The network will use these spike times as a
+# temporal backbone for encoding the target signal into its recurrent spiking activity.
+
+input_spike_prob = 0.05  # spike probability of frozen input noise
+dtype_in_spks = np.float32  # data type of input spikes - for reproducing TF results set to np.float32
+
+input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1)
+input_spike_bools[:, 0] = 0  # remove spikes in 0th time step of every sequence for technical reasons
+
+sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"]
+params_gen_spk_in = []
+for input_spike_bool in input_spike_bools:
+    input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool]
+    input_spike_times_all = [input_spike_times + start for start in sequence_starts]
+    params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)})
+
+####################
+
+nest.SetStatus(gen_spk_in, params_gen_spk_in)
+
+# %% ###########################################################################################################
+# Create output
+# ~~~~~~~~~~~~~
+# Then, we construct two sine-wave target signals, one with a single period and one with two periods per
+# sequence. Plotted against each other, they trace a figure eight, the "infinite loop" that gives this task its
+# name. These signals, like the input, are repeated for all iterations and fed into the rate generators that
+# were previously created.
+
+target_signal_list = [
+    np.sin(np.linspace(0.0, 2.0 * np.pi, steps["sequence"])),
+    np.sin(np.linspace(0.0, 4.0 * np.pi, steps["sequence"])),
+]
+
+params_gen_rate_target = []
+
+for target_signal in target_signal_list:
+    params_gen_rate_target.append(
+        {
+            "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"],
+            "amplitude_values": np.tile(target_signal, n_iter * n_batch),
+        }
+    )
+
+####################
+
+nest.SetStatus(gen_rate_target, params_gen_rate_target)
+
+# %% ###########################################################################################################
+# Force final update
+# ~~~~~~~~~~~~~~~~~~
+# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a
+# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the
+# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in
+# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop
+# synapse. This step is required purely for technical reasons.
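# %% ###########################################################################################################
# Illustrative aside: shape of the target pattern
# ...............................................
# A standalone sketch (not part of the example) of why this task is called "infinite loop": the two target
# signals defined above have a frequency ratio of 1:2, and plotting one against the sign-flipped other, as the
# pattern plot further below does, traces a figure eight.

import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0.0, 1.0, 1000)
y0 = np.sin(2.0 * np.pi * t)  # one period per sequence, cf. the first target signal above
y1 = np.sin(4.0 * np.pi * t)  # two periods per sequence, cf. the second target signal above

fig_sketch, ax_sketch = plt.subplots()
ax_sketch.plot(y0, -y1)  # figure-eight ("infinite loop") trajectory
ax_sketch.set_xlabel(r"$y^*_0$")
ax_sketch.set_ylabel(r"$-y^*_1$")
ax_sketch.axis("equal")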
+ +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. + + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_rec = mm_rec.get("events") +events_mm_out = mm_out.get("events") +events_sr = sr.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] +target_signal = events_mm_out["target_signal"] +senders = events_mm_out["senders"] + +loss_list = [] +for sender in set(senders): + idc = senders == sender + error = (readout_signal[idc] - target_signal[idc]) ** 2 + loss_list.append(0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"]))) + +loss = np.sum(loss_list, axis=0) + + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. 
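# %% ###########################################################################################################
# Illustrative aside: per-iteration loss with np.add.reduceat
# ...........................................................
# The loss computed above sums the squared error over each sequence with np.add.reduceat, which sums
# consecutive blocks of an array between the given start indices. A standalone toy example (values assumed,
# not taken from the simulation):

import numpy as np

error_toy = np.array([1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 0.5, 0.5, 0.5, 0.5])  # 3 sequences x 4 steps
loss_toy = 0.5 * np.add.reduceat(error_toy, np.arange(0, 12, 4))
print(loss_toy)  # [2. 4. 1.] - one loss value per sequence, i.e., per training iteration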
+
+do_plotting = True  # if True, plot the results
+
+if not do_plotting:
+    exit()
+
+colors = {
+    "blue": "#2854c5ff",
+    "red": "#e04b40ff",
+    "white": "#ffffffff",
+}
+
+plt.rcParams.update(
+    {
+        "font.sans-serif": "Arial",
+        "axes.spines.right": False,
+        "axes.spines.top": False,
+        "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]),
+    }
+)
+
+# %% ###########################################################################################################
+# Plot pattern
+# ............
+# First, we visualize the created pattern and plot the target for comparison. The outputs of the two readout
+# neurons encode the horizontal and vertical coordinates of the pattern, respectively.
+
+fig, ax = plt.subplots()
+
+ax.plot(
+    readout_signal[senders == list(set(senders))[0]][-steps["sequence"] :],
+    -readout_signal[senders == list(set(senders))[1]][-steps["sequence"] :],
+    c=colors["red"],
+    label="readout",
+)
+
+ax.plot(
+    target_signal[senders == list(set(senders))[0]][-steps["sequence"] :],
+    -target_signal[senders == list(set(senders))[1]][-steps["sequence"] :],
+    c=colors["blue"],
+    label="target",
+)
+
+ax.set_xlabel(r"$y_0$ and $y^*_0$")
+ax.set_ylabel(r"$y_1$ and $y^*_1$")
+
+ax.axis("equal")
+
+fig.tight_layout()
+
+# %% ###########################################################################################################
+# Plot training error
+# ...................
+# Next, we plot the training error of the network: the loss plotted against the iterations.
+
+fig, ax = plt.subplots()
+
+ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--")
+ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted")
+ax.plot(range(1, n_iter + 1), loss, label=r"$E$", c=colors["blue"])
+ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
+ax.set_xlabel("training iteration")
+ax.set_xlim(1, n_iter)
+ax.xaxis.get_major_locator().set_params(integer=True)
+ax.legend(bbox_to_anchor=(1.01, 0.5), loc="center left")
+fig.tight_layout()
+
+# %% ###########################################################################################################
+# Plot spikes and dynamic variables
+# .................................
+# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take
+# one snapshot in the first iteration and one snapshot at the end.
+ + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, nrns, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) + senders_subset = events["senders"][idc_times][idc_sender] + times_subset = events["times"][idc_times][idc_sender] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: + fig, axs = plt.subplots(12, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2}) + + plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + + plot_spikes(axs[3], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[4], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[5], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[6], events_mm_rec, "V_th_adapt", r"$A_j$" + "\n(mV)", xlims) + plot_recordable(axs[7], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[8], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[9], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[10], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[11], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. 
+ + +def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) + +plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course( + axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" +) +plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. + +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png new file mode 100644 index 0000000000..84ce96ed5e Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_handwriting.png differ diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png 
new file mode 100644
index 0000000000..445510390a
Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_infinite-loop.png differ
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png
new file mode 100644
index 0000000000..89e9d839fe
Binary files /dev/null and b/pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png differ
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py
new file mode 100644
index 0000000000..0f848a1390
--- /dev/null
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py
@@ -0,0 +1,648 @@
+# -*- coding: utf-8 -*-
+#
+# eprop_supervised_regression_sine-waves.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+r"""
+Tutorial on learning to generate sine waves with e-prop
+-------------------------------------------------------
+
+Training a regression model using supervised e-prop plasticity to generate sine waves
+
+Description
+~~~~~~~~~~~
+
+This script demonstrates supervised learning of a regression task with a recurrent spiking neural network that
+is equipped with the eligibility propagation (e-prop) plasticity mechanism by Bellec et al. [1]_.
+
+This type of learning is demonstrated on the proof-of-concept task in [1]_. We based this script on their
+TensorFlow script given in [2]_.
+
+In this task, the network learns to generate an arbitrary N-dimensional temporal pattern. Here, the
+network learns to reproduce with its overall spiking activity a one-dimensional, one-second-long target signal
+which is a superposition of four sine waves of different amplitudes, phases, and periods.
+
+.. image:: ../../../../pynest/examples/eprop_plasticity/eprop_supervised_regression_schematic_sine-waves.png
+   :width: 70 %
+   :alt: See Figure 1 below.
+   :align: center
+
+Learning in the neural network model is achieved by optimizing the connection weights with e-prop plasticity.
+This plasticity rule requires a specific network architecture depicted in Figure 1. The neural network model
+consists of a recurrent network that receives frozen noise input from Poisson generators and projects onto one
+readout neuron. The readout neuron compares the network signal :math:`y` with the teacher target signal
+:math:`y^*`, which it receives from a rate generator. In scenarios with multiple readout neurons, each individual
+readout signal denoted as :math:`y_k` is compared with a corresponding target signal represented as
+:math:`y_k^*`. The network's training error is assessed by employing a mean-squared error loss.
+
+Details on the event-based NEST implementation of e-prop can be found in [3]_.
+
+References
+~~~~~~~~~~
+
+..
[1] Bellec G, Scherr F, Subramoney A, Hajek E, Salaj D, Legenstein R, Maass W (2020). A solution to the
+       learning dilemma for recurrent networks of spiking neurons. Nature Communications, 11:3625.
+       https://doi.org/10.1038/s41467-020-17236-y
+
+.. [2] https://github.com/IGITUGraz/eligibility_propagation/blob/master/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py
+
+.. [3] Korcsak-Gorzo A, Stapmanns J, Espinoza Valverde JA, Dahmen D, van Albada SJ, Bolten M, Diesmann M.
+       Event-based implementation of eligibility propagation (in preparation)
+""" # pylint: disable=line-too-long # noqa: E501
+
+# %% ###########################################################################################################
+# Import libraries
+# ~~~~~~~~~~~~~~~~
+# We begin by importing all libraries required for the simulation, analysis, and visualization.
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import nest
+import numpy as np
+from cycler import cycler
+from IPython.display import Image
+
+# %% ###########################################################################################################
+# Schematic of network architecture
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# This figure, identical to the one in the description, shows the required network architecture in the center,
+# the input and output of the pattern generation task above, and lists of the required NEST device, neuron, and
+# synapse models below. The connections that must be established are numbered 1 to 6.
+
+try:
+    Image(filename="./eprop_supervised_regression_schematic_sine-waves.png")
+except Exception:
+    pass
+
+# %% ###########################################################################################################
+# Setup
+# ~~~~~
+
+# %% ###########################################################################################################
+# Initialize random generator
+# ...........................
+# We seed the numpy random generator, which will generate random initial weights as well as random input and
+# output.
+
+rng_seed = 1  # numpy random seed
+np.random.seed(rng_seed)  # fix numpy random seed
+
+# %% ###########################################################################################################
+# Define timing of task
+# .....................
+# The task's temporal structure is then defined, once as time steps and once as durations in milliseconds.
+ +n_batch = 1 # batch size, 1 in reference [2] +n_iter = 5 # number of iterations, 2000 in reference [2] + +steps = { + "sequence": 1000, # time steps of one full sequence +} + +steps["learning_window"] = steps["sequence"] # time steps of window with non-zero learning signals +steps["task"] = n_iter * n_batch * steps["sequence"] # time steps of task + +steps.update( + { + "offset_gen": 1, # offset since generator signals start from time step 1 + "delay_in_rec": 1, # connection delay between input and recurrent neurons + "delay_rec_out": 1, # connection delay between recurrent and output neurons + "delay_out_norm": 1, # connection delay between output neurons for normalization + "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + } +) + +steps["delays"] = steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] # time steps of delays + +steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset + +steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation + +duration = {"step": 1.0} # ms, temporal resolution of the simulation + +duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations + +# %% ########################################################################################################### +# Set up simulation +# ................. +# As last step of the setup, we reset the NEST kernel to remove all existing NEST simulation settings and +# objects and set some NEST kernel parameters, some of which are e-prop-related. + +params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, # if True, reset dynamic variables at start of each update interval + "eprop_update_interval": duration["sequence"], # ms, time interval for updating the synaptic weights + "print_time": False, # if True, print time progress bar during simulation, set False if run as code cell + "resolution": duration["step"], + "total_num_virtual_procs": 1, # number of virtual processes, set in case of distributed computing +} + +#################### + +nest.ResetKernel() +nest.set(**params_setup) + +# %% ########################################################################################################### +# Create neurons +# ~~~~~~~~~~~~~~ +# We proceed by creating a certain number of input, recurrent, and readout neurons and setting their parameters. +# Additionally, we already create an input spike generator and an output target rate generator, which we will +# configure later. 
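# %% ###########################################################################################################
# Illustrative aside: number of e-prop updates
# ............................................
# With the kernel parameters above, the e-prop update interval equals one sequence and the learning window
# spans the full sequence, so one gradient is computed per sequence. Assuming the values used in this script,
# the task therefore comprises n_iter * n_batch update intervals. This is a bookkeeping sketch only.

n_update_intervals_sketch = 5 * 1  # n_iter * n_batch sequences of 1000 ms, one update interval each
print(n_update_intervals_sketch)  # 5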
+ +n_in = 100 # number of input neurons +n_rec = 100 # number of recurrent neurons +n_out = 1 # number of readout neurons + +params_nrn_rec = { + "C_m": 1.0, # pF, membrane capacitance - takes effect only if neurons get current input (here not the case) + "c_reg": 300.0, # firing rate regularization scaling + "E_L": 0.0, # mV, leak / resting membrane potential + "f_target": 10.0, # spikes/s, target firing rate for firing rate regularization + "gamma": 0.3, # scaling of the pseudo derivative + "I_e": 0.0, # pA, external current input + "regular_spike_arrival": False, # If True, input spikes arrive at end of time step, if False at beginning + "surrogate_gradient_function": "piecewise_linear", # surrogate gradient / pseudo-derivative function + "t_ref": 0.0, # ms, duration of refractory period + "tau_m": 30.0, # ms, membrane time constant + "V_m": 0.0, # mV, initial value of the membrane voltage + "V_th": 0.03, # mV, spike threshold membrane voltage +} + +params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "mean_squared_error", # loss function + "regular_spike_arrival": False, + "tau_m": 30.0, + "V_m": 0.0, +} + +#################### + +# Intermediate parrot neurons required between input spike generators and recurrent neurons, +# since devices cannot establish plastic synapses for technical reasons + +gen_spk_in = nest.Create("spike_generator", n_in) +nrns_in = nest.Create("parrot_neuron", n_in) + +# The suffix _bsshslm_2020 follows the NEST convention to indicate in the model name the paper +# that introduced it by the first letter of the authors' last names and the publication year. + +nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", n_rec, params_nrn_rec) +nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) +gen_rate_target = nest.Create("step_rate_generator", n_out) + + +# %% ########################################################################################################### +# Create recorders +# ~~~~~~~~~~~~~~~~ +# We also create recorders, which, while not required for the training, will allow us to track various dynamic +# variables of the neurons, spikes, and changes in synaptic weights. To save computing time and memory, the +# recorders, the recorded variables, neurons, and synapses can be limited to the ones relevant to the +# experiment, and the recording interval can be increased (see the documentation on the specific recorders). By +# default, recordings are stored in memory but can also be written to file. 
+ +n_record = 1 # number of neurons to record dynamic variables from - this script requires n_record >= 1 +n_record_w = 3 # number of senders and targets to record weights from - this script requires n_record_w >=1 + +if n_record == 0 or n_record_w == 0: + raise ValueError("n_record and n_record_w >= 1 required") + +params_mm_rec = { + "interval": duration["step"], # interval between two recorded time points + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record + "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording + "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording +} + +params_mm_out = { + "interval": duration["step"], + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record + "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +params_sr = { + "start": duration["total_offset"], + "stop": duration["total_offset"] + duration["task"], +} + +#################### + +mm_rec = nest.Create("multimeter", params_mm_rec) +mm_out = nest.Create("multimeter", params_mm_out) +sr = nest.Create("spike_recorder", params_sr) +wr = nest.Create("weight_recorder", params_wr) + +nrns_rec_record = nrns_rec[:n_record] + +# %% ########################################################################################################### +# Create connections +# ~~~~~~~~~~~~~~~~~~ +# Now, we define the connectivity and set up the synaptic parameters, with the synaptic weights drawn from +# normal distributions. After these preparations, we establish the enumerated connections of the core network, +# as well as additional connections to the recorders. 
+ +params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} +params_conn_one_to_one = {"rule": "one_to_one"} + +dtype_weights = np.float32 # data type of weights - for reproducing TF results set to np.float32 +weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) +weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) +np.fill_diagonal(weights_rec_rec, 0.0) # since no autapses set corresponding weights to zero +weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) +weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + +params_common_syn_eprop = { + "optimizer": { + "type": "gradient_descent", # algorithm to optimize the weights + "batch_size": n_batch, + "eta": 1e-4, # learning rate + "Wmin": -100.0, # pA, minimal limit of the synaptic weights + "Wmax": 100.0, # pA, maximal limit of the synaptic weights + }, + "average_gradient": False, # if True, average the gradient over the learning window + "weight_recorder": wr, +} + +params_syn_base = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], # ms, dendritic delay + "tau_m_readout": params_nrn_out["tau_m"], # ms, for technical reasons pass readout neuron membrane time constant +} + +params_syn_in = params_syn_base.copy() +params_syn_in["weight"] = weights_in_rec # pA, initial values for the synaptic weights + +params_syn_rec = params_syn_base.copy() +params_syn_rec["weight"] = weights_rec_rec + +params_syn_out = params_syn_base.copy() +params_syn_out["weight"] = weights_rec_out + +params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, +} + +params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, # receptor type over which readout neuron receives target signal +} + +params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], +} + +#################### + +nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + +nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) # connection 1 +nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) # connection 2 +nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) # connection 3 +nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) # connection 4 +nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) # connection 5 +nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) # connection 6 + +nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + +nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) +nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + +# %% ########################################################################################################### +# Create input +# ~~~~~~~~~~~~ +# We generate some frozen Poisson spike noise of a fixed rate that is repeated in each iteration and feed these +# spike times to the previously created input spike generator. The network will use these spike times as a +# temporal backbone for encoding the target signal into its recurrent spiking activity. 
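# %% ###########################################################################################################
# Illustrative aside: role of the optimizer parameters
# ....................................................
# The common synapse properties above configure a plain gradient-descent optimizer with learning rate "eta" and
# weight limits "Wmin"/"Wmax". The sketch below only illustrates what such a step amounts to conceptually; the
# actual update is carried out inside eprop_synapse_bsshslm_2020 once per update interval, and with a batch
# size larger than one the gradient would first be accumulated over several update intervals.

import numpy as np


def gradient_descent_step_sketch(weight, gradient, eta=1e-4, Wmin=-100.0, Wmax=100.0):
    """Conceptual sketch: descend along the gradient and clip the weight to [Wmin, Wmax]."""
    return float(np.clip(weight - eta * gradient, Wmin, Wmax))


print(gradient_descent_step_sketch(weight=0.5, gradient=200.0))  # 0.48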
+ +input_spike_prob = 0.05 # spike probability of frozen input noise +dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + +input_spike_bools = (np.random.rand(steps["sequence"], n_in) < input_spike_prob).swapaxes(0, 1) +input_spike_bools[:, 0] = 0 # remove spikes in 0th time step of every sequence for technical reasons + +sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] +params_gen_spk_in = [] +for input_spike_bool in input_spike_bools: + input_spike_times = np.arange(0.0, duration["sequence"], duration["step"])[input_spike_bool] + input_spike_times_all = [input_spike_times + start for start in sequence_starts] + params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) + +#################### + +nest.SetStatus(gen_spk_in, params_gen_spk_in) + +# %% ########################################################################################################### +# Create output +# ~~~~~~~~~~~~~ +# Then, as a superposition of four sine waves with various durations, amplitudes, and phases, we construct a +# one-second target signal. This signal, like the input, is repeated for all iterations and fed into the rate +# generator that was previously created. + + +def generate_superimposed_sines(steps_sequence, periods): + n_sines = len(periods) + + amplitudes = np.random.uniform(low=0.5, high=2.0, size=n_sines) + phases = np.random.uniform(low=0.0, high=2.0 * np.pi, size=n_sines) + + sines = [ + A * np.sin(np.linspace(phi, phi + 2.0 * np.pi * (steps_sequence // T), steps_sequence)) + for A, phi, T in zip(amplitudes, phases, periods) + ] + + superposition = sum(sines) + superposition -= superposition[0] + superposition /= max(np.abs(superposition).max(), 1e-6) + return superposition + + +target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) # periods in steps + +params_gen_rate_target = { + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_values": np.tile(target_signal, n_iter * n_batch), +} + +#################### + +nest.SetStatus(gen_rate_target, params_gen_rate_target) + +# %% ########################################################################################################### +# Force final update +# ~~~~~~~~~~~~~~~~~~ +# Synapses only get active, that is, the correct weight update calculated and applied, when they transmit a +# spike. To still be able to read out the correct weights at the end of the simulation, we force spiking of the +# presynaptic neuron and thus an update of all synapses, including those that have not transmitted a spike in +# the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop +# synapse. This step is required purely for technical reasons. + +gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) + +nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) + +# %% ########################################################################################################### +# Read out pre-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Before we begin training, we read out the initial weight matrices so that we can eventually compare them to +# the optimized weights. 
+ + +def get_weights(pop_pre, pop_post): + conns = nest.GetConnections(pop_pre, pop_post).get(["source", "target", "weight"]) + conns["senders"] = np.array(conns["source"]) - np.min(conns["source"]) + conns["targets"] = np.array(conns["target"]) - np.min(conns["target"]) + + conns["weight_matrix"] = np.zeros((len(pop_post), len(pop_pre))) + conns["weight_matrix"][conns["targets"], conns["senders"]] = conns["weight"] + return conns + + +weights_pre_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Simulate +# ~~~~~~~~ +# We train the network by simulating for a set simulation time, determined by the number of iterations and the +# batch size and the length of one sequence. + +nest.Simulate(duration["sim"]) + +# %% ########################################################################################################### +# Read out post-training weights +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# After the training, we can read out the optimized final weights. + +weights_post_train = { + "in_rec": get_weights(nrns_in, nrns_rec), + "rec_rec": get_weights(nrns_rec, nrns_rec), + "rec_out": get_weights(nrns_rec, nrns_out), +} + +# %% ########################################################################################################### +# Read out recorders +# ~~~~~~~~~~~~~~~~~~ +# We can also retrieve the recorded history of the dynamic variables and weights, as well as detected spikes. + +events_mm_rec = mm_rec.get("events") +events_mm_out = mm_out.get("events") +events_sr = sr.get("events") +events_wr = wr.get("events") + +# %% ########################################################################################################### +# Evaluate training error +# ~~~~~~~~~~~~~~~~~~~~~~~ +# We evaluate the network's training error by calculating a loss - in this case, the mean squared error between +# the integrated recurrent network activity and the target rate. + +readout_signal = events_mm_out["readout_signal"] +target_signal = events_mm_out["target_signal"] + +error = (readout_signal - target_signal) ** 2 +loss = 0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"])) + +# %% ########################################################################################################### +# Plot results +# ~~~~~~~~~~~~ +# Then, we plot a series of plots. + +do_plotting = True # if True, plot the results + +if not do_plotting: + exit() + +colors = { + "blue": "#2854c5ff", + "red": "#e04b40ff", + "white": "#ffffffff", +} + +plt.rcParams.update( + { + "font.sans-serif": "Arial", + "axes.spines.right": False, + "axes.spines.top": False, + "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + } +) + +# %% ########################################################################################################### +# Plot training error +# ................... +# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations. 
+ +fig, ax = plt.subplots() + +ax.plot(range(1, n_iter + 1), loss) +ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$") +ax.set_xlabel("training iteration") +ax.set_xlim(1, n_iter) +ax.xaxis.get_major_locator().set_params(integer=True) + +fig.tight_layout() + +# %% ########################################################################################################### +# Plot spikes and dynamic variables +# ................................. +# This plotting routine shows how to plot all of the recorded dynamic variables and spikes across time. We take +# one snapshot in the first iteration and one snapshot at the end. + + +def plot_recordable(ax, events, recordable, ylabel, xlims): + for sender in set(events["senders"]): + idc_sender = events["senders"] == sender + idc_times = (events["times"][idc_sender] > xlims[0]) & (events["times"][idc_sender] < xlims[1]) + ax.plot(events["times"][idc_sender][idc_times], events[recordable][idc_sender][idc_times], lw=0.5) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(events[recordable]) - np.min(events[recordable])) * 0.1 + ax.set_ylim(np.min(events[recordable]) - margin, np.max(events[recordable]) + margin) + + +def plot_spikes(ax, events, nrns, ylabel, xlims): + idc_times = (events["times"] > xlims[0]) & (events["times"] < xlims[1]) + idc_sender = np.isin(events["senders"][idc_times], nrns.tolist()) + senders_subset = events["senders"][idc_times][idc_sender] + times_subset = events["times"][idc_times][idc_sender] + + ax.scatter(times_subset, senders_subset, s=0.1) + ax.set_ylabel(ylabel) + margin = np.abs(np.max(senders_subset) - np.min(senders_subset)) * 0.1 + ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin) + + +for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]: + fig, axs = plt.subplots(9, 1, sharex=True, figsize=(6, 8), gridspec_kw={"hspace": 0.4, "left": 0.2}) + + plot_spikes(axs[0], events_sr, nrns_in, r"$z_i$" + "\n", xlims) + plot_spikes(axs[1], events_sr, nrns_rec, r"$z_j$" + "\n", xlims) + + plot_recordable(axs[2], events_mm_rec, "V_m", r"$v_j$" + "\n(mV)", xlims) + plot_recordable(axs[3], events_mm_rec, "surrogate_gradient", r"$\psi_j$" + "\n", xlims) + plot_recordable(axs[4], events_mm_rec, "learning_signal", r"$L_j$" + "\n(pA)", xlims) + + plot_recordable(axs[5], events_mm_out, "V_m", r"$v_k$" + "\n(mV)", xlims) + plot_recordable(axs[6], events_mm_out, "target_signal", r"$y^*_k$" + "\n", xlims) + plot_recordable(axs[7], events_mm_out, "readout_signal", r"$y_k$" + "\n", xlims) + plot_recordable(axs[8], events_mm_out, "error_signal", r"$y_k-y^*_k$" + "\n", xlims) + + axs[-1].set_xlabel(r"$t$ (ms)") + axs[-1].set_xlim(*xlims) + + fig.align_ylabels() + +# %% ########################################################################################################### +# Plot weight time courses +# ........................ +# Similarly, we can plot the weight histories. Note that the weight recorder, attached to the synapses, works +# differently than the other recorders. Since synapses only get activated when they transmit a spike, the weight +# recorder only records the weight in those moments. That is why the first weight registrations do not start in +# the first time step and we add the initial weights manually. 
+ + +def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): + for sender in nrns_senders.tolist(): + for target in nrns_targets.tolist(): + idc_syn = (events["senders"] == sender) & (events["targets"] == target) + idc_syn_pre = (weights_pre_train[label]["source"] == sender) & ( + weights_pre_train[label]["target"] == target + ) + + times = [0.0] + events["times"][idc_syn].tolist() + weights = [weights_pre_train[label]["weight"][idc_syn_pre]] + events["weights"][idc_syn].tolist() + + ax.step(times, weights, c=colors["blue"]) + ax.set_ylabel(ylabel) + ax.set_ylim(-0.6, 0.6) + + +fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) + +plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course( + axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" +) +plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") + +axs[-1].set_xlabel(r"$t$ (ms)") +axs[-1].set_xlim(0, steps["task"]) + +fig.align_ylabels() +fig.tight_layout() + +# %% ########################################################################################################### +# Plot weight matrices +# .................... +# If one is not interested in the time course of the weights, it is possible to read out only the initial and +# final weights, which requires less computing time and memory than the weight recorder approach. Here, we plot +# the corresponding weight matrices before and after the optimization. + +cmap = mpl.colors.LinearSegmentedColormap.from_list( + "cmap", ((0.0, colors["blue"]), (0.5, colors["white"]), (1.0, colors["red"])) +) + +fig, axs = plt.subplots(3, 2, sharex="col", sharey="row") + +all_w_extrema = [] + +for k in weights_pre_train.keys(): + w_pre = weights_pre_train[k]["weight"] + w_post = weights_post_train[k]["weight"] + all_w_extrema.append([np.min(w_pre), np.max(w_pre), np.min(w_post), np.max(w_post)]) + +args = {"cmap": cmap, "vmin": np.min(all_w_extrema), "vmax": np.max(all_w_extrema)} + +for i, weights in zip([0, 1], [weights_pre_train, weights_post_train]): + axs[0, i].pcolormesh(weights["in_rec"]["weight_matrix"].T, **args) + axs[1, i].pcolormesh(weights["rec_rec"]["weight_matrix"], **args) + cmesh = axs[2, i].pcolormesh(weights["rec_out"]["weight_matrix"], **args) + + axs[2, i].set_xlabel("recurrent\nneurons") + +axs[0, 0].set_ylabel("input\nneurons") +axs[1, 0].set_ylabel("recurrent\nneurons") +axs[2, 0].set_ylabel("readout\nneurons") +fig.align_ylabels(axs[:, 0]) + +axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center") +axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center") + +axs[2, 0].yaxis.get_major_locator().set_params(integer=True) + +cbar = plt.colorbar(cmesh, cax=axs[1, 1].inset_axes([1.1, 0.2, 0.05, 0.8]), label="weight (pA)") + +fig.tight_layout() + +plt.show() diff --git a/pynest/nest/__init__.py b/pynest/nest/__init__.py index 398d1b82dd..0aa91fd22b 100644 --- a/pynest/nest/__init__.py +++ b/pynest/nest/__init__.py @@ -398,7 +398,21 @@ def __dir__(self): ), default=float("+inf"), ) - + eprop_update_interval = KernelAttribute( + "float", + ("Task-specific update interval of the e-prop plasticity mechanism [ms]."), + default=1000.0, + ) + eprop_learning_window = KernelAttribute( + "float", + ("Task-specific learning window of the e-prop plasticity mechanism [ms]."), + default=1000.0, + ) + 
eprop_reset_neurons_on_update = KernelAttribute( + "bool", + ("If True, reset dynamic variables of e-prop neurons upon e-prop update."), + default=True, + ) # Kernel attribute indices, used for fast lookup in `ll_api.py` _kernel_attr_names = builtins.set(k for k, v in vars().items() if isinstance(v, KernelAttribute)) _readonly_kernel_attrs = builtins.set( diff --git a/testsuite/pytests/sli2py_connect/test_common_properties_setting.py b/testsuite/pytests/sli2py_connect/test_common_properties_setting.py index 80d81ec548..2d1f2ee687 100644 --- a/testsuite/pytests/sli2py_connect/test_common_properties_setting.py +++ b/testsuite/pytests/sli2py_connect/test_common_properties_setting.py @@ -37,18 +37,18 @@ def set_volume_transmitter(): def set_default_delay_resolution(): - nest.resolution = nest.GetDefaults("eprop_synapse")["delay"] + nest.resolution = nest.GetDefaults("eprop_synapse_bsshslm_2020")["delay"] # This list shall contain all synapse models extending the CommonSynapseProperties class. # For each model, specify which parameter to test with and which test value to use. A # setup function can be provided if preparations are required. Provide also supported neuron model. common_prop_models = { - "eprop_synapse": { - "parameter": "batch_size", - "value": 10, + "eprop_synapse_bsshslm_2020": { + "parameter": "average_gradient", + "value": not nest.GetDefaults("eprop_synapse_bsshslm_2020")["average_gradient"], "setup": set_default_delay_resolution, - "neuron": "eprop_iaf_psc_delta", + "neuron": "eprop_iaf_bsshslm_2020", }, "jonke_synapse": {"parameter": "tau_plus", "value": 10, "setup": None, "neuron": "iaf_psc_alpha"}, "stdp_dopamine_synapse": { diff --git a/testsuite/pytests/sli2py_neurons/test_neurons_handle_multiplicity.py b/testsuite/pytests/sli2py_neurons/test_neurons_handle_multiplicity.py index e90075a99e..4490bc6485 100644 --- a/testsuite/pytests/sli2py_neurons/test_neurons_handle_multiplicity.py +++ b/testsuite/pytests/sli2py_neurons/test_neurons_handle_multiplicity.py @@ -90,8 +90,12 @@ } -def test_spike_multiplicity_parrot_neuron(): +@pytest.fixture(autouse=True) +def reset(): nest.ResetKernel() + + +def test_spike_multiplicity_parrot_neuron(): multiplicities = [1, 3, 2] spikes = [1.0, 2.0, 3.0] sg = nest.Create( @@ -122,8 +126,6 @@ def test_spike_multiplicity_parrot_neuron(): ], ) def test_spike_multiplicity(model): - nest.ResetKernel() - n1 = nest.Create(model) n2 = nest.Create(model) diff --git a/testsuite/pytests/sli2py_recording/test_multimeter_stepping.py b/testsuite/pytests/sli2py_recording/test_multimeter_stepping.py index 5929630174..6eb1d76b6c 100644 --- a/testsuite/pytests/sli2py_recording/test_multimeter_stepping.py +++ b/testsuite/pytests/sli2py_recording/test_multimeter_stepping.py @@ -29,6 +29,7 @@ import pytest skip_models = [ + "eprop_readout_bsshslm_2020", # extra timestep added to some recordables in update function "erfc_neuron", # binary neuron "ginzburg_neuron", # binary neuron "mcculloch_pitts_neuron", # binary neuron @@ -88,10 +89,9 @@ def build_net(model): """ nest.ResetKernel() - nrn = nest.Create(model) pg = nest.Create("poisson_generator", params={"rate": 1e4}) - mm = nest.Create("multimeter", {"interval": 0.1, "record_from": nrn.recordables}) + mm = nest.Create("multimeter", {"interval": nest.resolution, "record_from": nrn.recordables}) receptor_type = 0 if model in extra_params.keys(): diff --git a/testsuite/pytests/sli2py_regressions/test_issue_77.py b/testsuite/pytests/sli2py_regressions/test_issue_77.py index ad3601637e..aad0849b88 100644 --- 
a/testsuite/pytests/sli2py_regressions/test_issue_77.py +++ b/testsuite/pytests/sli2py_regressions/test_issue_77.py @@ -58,6 +58,9 @@ "music_rate_in_proxy", # MUSIC device "music_rate_out_proxy", # MUSIC device "astrocyte_lr_1994", # does not send spikes + "eprop_readout_bsshslm_2020", # does not send spikes + "eprop_iaf_bsshslm_2020", # does not support stdp synapses + "eprop_iaf_adapt_bsshslm_2020", # does not support stdp synapses ] # The following models require connections to rport 1 or other specific parameters: diff --git a/testsuite/pytests/sli2py_stimulating/test_dcgen_versus_I_e.py b/testsuite/pytests/sli2py_stimulating/test_dcgen_versus_I_e.py index bca3fe1ba4..f636ffc797 100644 --- a/testsuite/pytests/sli2py_stimulating/test_dcgen_versus_I_e.py +++ b/testsuite/pytests/sli2py_stimulating/test_dcgen_versus_I_e.py @@ -41,8 +41,7 @@ def test_dcgen_vs_I_e(model): # Models requiring special parameters if model in ["gif_psc_exp", "gif_cond_exp", "gif_psc_exp_multisynapse", "gif_cond_exp_multisynapse"]: nest.SetDefaults(model, params={"lambda_0": 0.0}) - - if model == "pp_psc_delta": + elif model == "pp_psc_delta": nest.SetDefaults(model, params={"c_2": 0.0}) # Create two neurons diff --git a/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py new file mode 100644 index 0000000000..0f65167d62 --- /dev/null +++ b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py @@ -0,0 +1,749 @@ +# -*- coding: utf-8 -*- +# +# test_eprop_bsshslm_2020_plasticity.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +""" +Test functionality of e-prop plasticity. 
+""" + +import nest +import numpy as np +import pytest + +nest.set_verbosity("M_WARNING") + +supported_source_models = ["eprop_iaf_bsshslm_2020", "eprop_iaf_adapt_bsshslm_2020"] +supported_target_models = supported_source_models + ["eprop_readout_bsshslm_2020"] + + +@pytest.fixture(autouse=True) +def fix_resolution(): + nest.ResetKernel() + + +@pytest.mark.parametrize("source_model", supported_source_models) +@pytest.mark.parametrize("target_model", supported_target_models) +def test_connect_with_eprop_synapse(source_model, target_model): + """Ensures that the restriction to supported neuron models works.""" + + # Connect supported models with e-prop synapse + src = nest.Create(source_model) + tgt = nest.Create(target_model) + nest.Connect(src, tgt, "all_to_all", {"synapse_model": "eprop_synapse_bsshslm_2020", "delay": nest.resolution}) + + +@pytest.mark.parametrize("target_model", set(nest.node_models) - set(supported_target_models)) +def test_unsupported_model_raises(target_model): + """Confirm that connecting a non-eprop neuron as target via an eprop_synapse_bsshslm_2020 raises an error.""" + + src_nrn = nest.Create(supported_source_models[0]) + tgt_nrn = nest.Create(target_model) + + with pytest.raises(nest.kernel.NESTError): + nest.Connect(src_nrn, tgt_nrn, "all_to_all", {"synapse_model": "eprop_synapse_bsshslm_2020"}) + + +def test_eprop_regression(): + """ + Test correct computation of losses for a regression task + (for details on the task, see nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py) + by comparing the simulated losses with + + 1. NEST reference losses to catch scenarios in which the e-prop model does not work as intended (e.g., + potential future changes to the NEST code base or a faulty installation). These reference losses + were obtained from a simulation with the verified NEST e-prop implementation run with + Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and NEST@3304c6b5c. + + 2. TensorFlow reference losses to check the faithfulness to the original model. These reference losses were + obtained from a simulation with the original TensorFlow implementation + (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_pattern_generation.py, + a modified fork of the original model at https://github.com/IGITUGraz/eligibility_propagation) run with + Linux 4.15.0-213-generic, Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and + INM6/eligibility_propagation@7df7d2627. 
+ """ # pylint: disable=line-too-long # noqa: E501 + + # Initialize random generator + rng_seed = 1 + np.random.seed(rng_seed) + + # Define timing of task + + n_batch = 1 + n_iter = 5 + + steps = { + "sequence": 1000, + } + + steps["learning_window"] = steps["sequence"] + steps["task"] = n_iter * n_batch * steps["sequence"] + + steps.update( + { + "offset_gen": 1, + "delay_in_rec": 1, + "delay_rec_out": 1, + "delay_out_norm": 1, + "extension_sim": 1, + } + ) + + steps["total_offset"] = ( + steps["offset_gen"] + steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] + ) + + steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] + + duration = {"step": 1.0} + + duration.update({key: value * duration["step"] for key, value in steps.items()}) + + # Set up simulation + + params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, + "eprop_update_interval": duration["sequence"], + "print_time": False, + "resolution": duration["step"], + "total_num_virtual_procs": 1, + } + + nest.ResetKernel() + nest.set(**params_setup) + + # Create neurons + + n_in = 100 + n_rec = 100 + n_out = 1 + + params_nrn_rec = { + "C_m": 1.0, + "c_reg": 300.0, + "gamma": 0.3, + "E_L": 0.0, + "f_target": 10.0, + "I_e": 0.0, + "regular_spike_arrival": False, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 0.0, + "tau_m": 30.0, + "V_m": 0.0, + "V_th": 0.03, + } + + params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "mean_squared_error", + "regular_spike_arrival": False, + "tau_m": 30.0, + "V_m": 0.0, + } + + gen_spk_in = nest.Create("spike_generator", n_in) + nrns_in = nest.Create("parrot_neuron", n_in) + nrns_rec = nest.Create("eprop_iaf_bsshslm_2020", n_rec, params_nrn_rec) + nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) + gen_rate_target = nest.Create("step_rate_generator", n_out) + + # Create recorders + + n_record = 1 + n_record_w = 1 + + params_mm_rec = { + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "interval": duration["sequence"], + } + + params_mm_out = { + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "interval": duration["step"], + } + + params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], + "targets": nrns_rec[:n_record_w] + nrns_out, + } + + mm_rec = nest.Create("multimeter", params_mm_rec) + mm_out = nest.Create("multimeter", params_mm_out) + sr = nest.Create("spike_recorder") + wr = nest.Create("weight_recorder", params_wr) + + nrns_rec_record = nrns_rec[:n_record] + + # Create connections + + params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} + params_conn_one_to_one = {"rule": "one_to_one"} + + dtype_weights = np.float32 + weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) + weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) + np.fill_diagonal(weights_rec_rec, 0.0) + weights_rec_out = np.array(np.random.randn(n_rec, n_out).T / np.sqrt(n_rec), dtype=dtype_weights) + weights_out_rec = np.array(np.random.randn(n_rec, n_out) / np.sqrt(n_rec), dtype=dtype_weights) + + params_common_syn_eprop = { + "optimizer": { + "type": "gradient_descent", + "batch_size": n_batch, + "eta": 1e-4, + "Wmin": -100.0, + "Wmax": 100.0, + }, + "weight_recorder": wr, + 
"average_gradient": False, + } + + params_syn_in = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_in_rec, + } + + params_syn_rec = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_rec_rec, + } + + params_syn_out = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_rec_out, + } + + params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, + } + + params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, + } + + params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], + } + + nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + + nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) + nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) + nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) + nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) + nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) + nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) + + nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + + nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) + nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + + # Create input + + input_spike_prob = 0.05 + dtype_in_spks = np.float32 + + input_spike_bools = np.random.rand(n_batch, steps["sequence"], n_in) < input_spike_prob + input_spike_bools = np.hstack(input_spike_bools.swapaxes(1, 2)) + input_spike_bools[:, 0] = 0 + + sequence_starts = np.arange(0.0, duration["task"], duration["sequence"]) + duration["offset_gen"] + params_gen_spk_in = [] + for input_spike_bool in input_spike_bools: + input_spike_times = np.arange(0.0, duration["sequence"] * n_batch, duration["step"])[input_spike_bool] + input_spike_times_all = [input_spike_times + start for start in sequence_starts] + params_gen_spk_in.append({"spike_times": np.hstack(input_spike_times_all).astype(dtype_in_spks)}) + + nest.SetStatus(gen_spk_in, params_gen_spk_in) + + # Create output + + def generate_superimposed_sines(steps_sequence, periods): + n_sines = len(periods) + + amplitudes = np.random.uniform(low=0.5, high=2.0, size=n_sines) + phases = np.random.uniform(low=0.0, high=2.0 * np.pi, size=n_sines) + + sines = [ + A * np.sin(np.linspace(phi, phi + 2.0 * np.pi * (steps_sequence // T), steps_sequence)) + for A, phi, T in zip(amplitudes, phases, periods) + ] + + superposition = sum(sines) + superposition -= superposition[0] + superposition /= max(np.abs(superposition).max(), 1e-6) + return superposition + + target_signal = generate_superimposed_sines(steps["sequence"], [1000, 500, 333, 200]) + + params_gen_rate_target = { + "amplitude_times": np.arange(0.0, duration["task"], duration["step"]) + duration["total_offset"], + "amplitude_values": np.tile(target_signal, n_iter * n_batch), + } + + nest.SetStatus(gen_rate_target, params_gen_rate_target) + + # Simulate + + nest.Simulate(duration["sim"]) + + # Read out recorders + + events_mm_out = mm_out.get("events") + + # Evaluate 
training error + + readout_signal = events_mm_out["readout_signal"] + target_signal = events_mm_out["target_signal"] + + error = (readout_signal - target_signal) ** 2 + loss = 0.5 * np.add.reduceat(error, np.arange(0, steps["task"], steps["sequence"])) + + # Verify results + + loss_NEST_reference = np.array( + [ + 101.964356999041, + 103.466731126205, + 103.340607074771, + 103.680244037686, + 104.412775748752, + ] + ) + + loss_TF_reference = np.array( + [ + 101.964363098144, + 103.466735839843, + 103.340606689453, + 103.680244445800, + 104.412780761718, + ] + ) + + assert np.allclose(loss, loss_NEST_reference, rtol=1e-8) + assert np.allclose(loss, loss_TF_reference, rtol=1e-7) + + +def test_eprop_classification(): + """ + Test correct computation of losses for a classification task + (for details on the task, see nest-simulator/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py) + by comparing the simulated losses with + + 1. NEST reference losses to catch scenarios in which the e-prop model does not work as intended (e.g., + potential future changes to the NEST code base or a faulty installation). These reference losses + were obtained from a simulation with the verified NEST e-prop implementation run with + Linux 4.15.0-213-generic, Python v3.11.6, Numpy v1.26.0, and NEST@3304c6b5c. + + 2. TensorFlow reference losses to check the faithfulness to the original model. These reference losses were + obtained from a simulation with the original TensorFlow implementation + (https://github.com/INM-6/eligibility_propagation/blob/eprop_in_nest/Figure_3_and_S7_e_prop_tutorials/tutorial_evidence_accumulation_with_alif.py, + a modified fork of the original model at https://github.com/IGITUGraz/eligibility_propagation) run with + Linux 4.15.0-213-generic, Python v3.6.10, Numpy v1.18.0, TensorFlow v1.15.0, and + INM6/eligibility_propagation@7df7d2627. 
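+
+    The loss of each iteration is the cross entropy between the normalized
+    readout signal and the target signal, summed over the readout neurons and
+    averaged over the recall window and the batch, i.e.,
+    loss = -mean( sum_over_readouts target * log(readout) ),
+    computed from the multimeter recordings and compared to the reference
+    losses with np.allclose.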
+ """ # pylint: disable=line-too-long # noqa: E501 + + # Initialize random generator + + rng_seed = 1 + np.random.seed(rng_seed) + + # Define timing of task + + n_batch = 1 + n_iter = 5 + + n_input_symbols = 4 + n_cues = 7 + prob_group = 0.3 + + steps = { + "cue": 100, + "spacing": 50, + "bg_noise": 1050, + "recall": 150, + } + + steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) + steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] + steps["learning_window"] = steps["recall"] + steps["task"] = n_iter * n_batch * steps["sequence"] + + steps.update( + { + "offset_gen": 1, + "delay_in_rec": 1, + "delay_rec_out": 1, + "delay_out_norm": 1, + "extension_sim": 1, + } + ) + + steps["total_offset"] = ( + steps["offset_gen"] + steps["delay_in_rec"] + steps["delay_rec_out"] + steps["delay_out_norm"] + ) + + steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] + + duration = {"step": 1.0} + + duration.update({key: value * duration["step"] for key, value in steps.items()}) + + # Set up simulation + + params_setup = { + "eprop_learning_window": duration["learning_window"], + "eprop_reset_neurons_on_update": True, + "eprop_update_interval": duration["sequence"], + "print_time": False, + "resolution": duration["step"], + "total_num_virtual_procs": 1, + } + + nest.ResetKernel() + nest.set(**params_setup) + + # Create neurons + + n_in = 40 + n_ad = 50 + n_reg = 50 + n_rec = n_ad + n_reg + n_out = 2 + + params_nrn_reg = { + "C_m": 1.0, + "c_reg": 2.0, + "E_L": 0.0, + "f_target": 10.0, + "gamma": 0.3, + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 5.0, + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, + } + + params_nrn_ad = { + "adapt_tau": 2000.0, + "adaptation": 0.0, + "C_m": 1.0, + "c_reg": 2.0, + "E_L": 0.0, + "f_target": 10.0, + "gamma": 0.3, + "I_e": 0.0, + "regular_spike_arrival": True, + "surrogate_gradient_function": "piecewise_linear", + "t_ref": 5.0, + "tau_m": 20.0, + "V_m": 0.0, + "V_th": 0.6, + } + + params_nrn_ad["adapt_beta"] = ( + 1.7 * (1.0 - np.exp(-1.0 / params_nrn_ad["adapt_tau"])) / (1.0 - np.exp(-1.0 / params_nrn_ad["tau_m"])) + ) + + params_nrn_out = { + "C_m": 1.0, + "E_L": 0.0, + "I_e": 0.0, + "loss": "cross_entropy", + "regular_spike_arrival": False, + "tau_m": 20.0, + "V_m": 0.0, + } + + gen_spk_in = nest.Create("spike_generator", n_in) + nrns_in = nest.Create("parrot_neuron", n_in) + nrns_reg = nest.Create("eprop_iaf_bsshslm_2020", n_reg, params_nrn_reg) + nrns_ad = nest.Create("eprop_iaf_adapt_bsshslm_2020", n_ad, params_nrn_ad) + nrns_out = nest.Create("eprop_readout_bsshslm_2020", n_out, params_nrn_out) + gen_rate_target = nest.Create("step_rate_generator", n_out) + + nrns_rec = nrns_reg + nrns_ad + + # Create recorders + + n_record = 1 + n_record_w = 1 + + params_mm_rec = { + "record_from": ["V_m", "surrogate_gradient", "learning_signal"], + "start": duration["offset_gen"] + duration["delay_in_rec"], + "interval": duration["sequence"], + } + + params_mm_out = { + "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], + "start": duration["total_offset"], + "interval": duration["step"], + } + + params_wr = { + "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], + "targets": nrns_rec[:n_record_w] + nrns_out, + } + + mm_rec = nest.Create("multimeter", params_mm_rec) + mm_out = nest.Create("multimeter", params_mm_out) + sr = nest.Create("spike_recorder") + wr = nest.Create("weight_recorder", params_wr) + + 
nrns_rec_record = nrns_rec[:n_record] + + # Create connections + + params_conn_all_to_all = {"rule": "all_to_all", "allow_autapses": False} + params_conn_one_to_one = {"rule": "one_to_one"} + + def calculate_glorot_dist(fan_in, fan_out): + glorot_scale = 1.0 / max(1.0, (fan_in + fan_out) / 2.0) + glorot_limit = np.sqrt(3.0 * glorot_scale) + glorot_distribution = np.random.uniform(low=-glorot_limit, high=glorot_limit, size=(fan_in, fan_out)) + return glorot_distribution + + dtype_weights = np.float32 + weights_in_rec = np.array(np.random.randn(n_in, n_rec).T / np.sqrt(n_in), dtype=dtype_weights) + weights_rec_rec = np.array(np.random.randn(n_rec, n_rec).T / np.sqrt(n_rec), dtype=dtype_weights) + np.fill_diagonal(weights_rec_rec, 0.0) + weights_rec_out = np.array(calculate_glorot_dist(n_rec, n_out).T, dtype=dtype_weights) + weights_out_rec = np.array(np.random.randn(n_rec, n_out), dtype=dtype_weights) + + params_common_syn_eprop = { + "optimizer": { + "type": "adam", + "batch_size": n_batch, + "beta_1": 0.9, + "beta_2": 0.999, + "epsilon": 1e-8, + "eta": 5e-3, + "Wmin": -100.0, + "Wmax": 100.0, + }, + "weight_recorder": wr, + "average_gradient": True, + } + + params_syn_in = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_in_rec, + } + + params_syn_rec = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_rec_rec, + } + + params_syn_out = { + "synapse_model": "eprop_synapse_bsshslm_2020", + "delay": duration["step"], + "tau_m_readout": params_nrn_out["tau_m"], + "weight": weights_rec_out, + } + + params_syn_feedback = { + "synapse_model": "eprop_learning_signal_connection_bsshslm_2020", + "delay": duration["step"], + "weight": weights_out_rec, + } + + params_syn_out_out = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 1, + "weight": 1.0, + } + + params_syn_rate_target = { + "synapse_model": "rate_connection_delayed", + "delay": duration["step"], + "receptor_type": 2, + } + + params_syn_static = { + "synapse_model": "static_synapse", + "delay": duration["step"], + } + + nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + + nest.Connect(gen_spk_in, nrns_in, params_conn_one_to_one, params_syn_static) + nest.Connect(nrns_in, nrns_rec, params_conn_all_to_all, params_syn_in) + nest.Connect(nrns_rec, nrns_rec, params_conn_all_to_all, params_syn_rec) + nest.Connect(nrns_rec, nrns_out, params_conn_all_to_all, params_syn_out) + nest.Connect(nrns_out, nrns_rec, params_conn_all_to_all, params_syn_feedback) + nest.Connect(gen_rate_target, nrns_out, params_conn_one_to_one, params_syn_rate_target) + nest.Connect(nrns_out, nrns_out, params_conn_all_to_all, params_syn_out_out) + + nest.Connect(nrns_in + nrns_rec, sr, params_conn_all_to_all, params_syn_static) + + nest.Connect(mm_rec, nrns_rec_record, params_conn_all_to_all, params_syn_static) + nest.Connect(mm_out, nrns_out, params_conn_all_to_all, params_syn_static) + + # Create input and output + + def generate_evidence_accumulation_input_output( + n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + ): + n_pop_nrn = n_in // n_input_symbols + + prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) + idx = np.random.choice([0, 1], n_batch) + probs = np.zeros((n_batch, 2), dtype=np.float32) + probs[:, 0] = prob_choices[idx] + probs[:, 1] = prob_choices[1 - idx] 
+ + batched_cues = np.zeros((n_batch, n_cues), dtype=int) + for b_idx in range(n_batch): + batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + + input_spike_probs = np.zeros((n_batch, steps["sequence"], n_in)) + + for b_idx in range(n_batch): + for c_idx in range(n_cues): + cue = batched_cues[b_idx, c_idx] + + step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] + step_stop = step_start + steps["cue"] + + pop_nrn_start = cue * n_pop_nrn + pop_nrn_stop = pop_nrn_start + n_pop_nrn + + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob + input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) + input_spike_bools[:, 0, :] = 0 + + target_cues = np.zeros(n_batch, dtype=int) + target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + + return input_spike_bools, target_cues + + input_spike_prob = 0.04 + dtype_in_spks = np.float32 + + input_spike_bools_list = [] + target_cues_list = [] + + for iteration in range(n_iter): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output( + n_batch, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + ) + input_spike_bools_list.append(input_spike_bools) + target_cues_list.extend(target_cues.tolist()) + + input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in) + timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] + + params_gen_spk_in = [ + {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} + for nrn_in_idx in range(n_in) + ] + + target_rate_changes = np.zeros((n_out, n_batch * n_iter)) + target_rate_changes[np.array(target_cues_list), np.arange(n_batch * n_iter)] = 1 + + params_gen_rate_target = [ + { + "amplitude_times": np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"], + "amplitude_values": target_rate_changes[nrn_out_idx], + } + for nrn_out_idx in range(n_out) + ] + + nest.SetStatus(gen_spk_in, params_gen_spk_in) + nest.SetStatus(gen_rate_target, params_gen_rate_target) + + # Simulate + + nest.Simulate(duration["sim"]) + + # Read out recorders + + events_mm_out = mm_out.get("events") + + # Evaluate training error + + readout_signal = events_mm_out["readout_signal"] + target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + + readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) + target_signal = np.array([target_signal[senders == i] for i in set(senders)]) + + readout_signal = readout_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) + readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] + + target_signal = target_signal.reshape((n_out, n_iter, n_batch, steps["sequence"])) + target_signal = target_signal[:, :, :, -steps["learning_window"] :] + + loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) + + # Verify results + + loss_NEST_reference = np.array( + [ + 0.741152550006, + 0.740388187700, + 0.665785233177, + 0.663644193322, + 0.729428962844, + ] + ) + + loss_TF_reference = np.array( + [ + 0.741152524948, + 0.740388214588, + 0.665785133838, + 0.663644134998, + 0.729429066181, + ] + ) + + assert np.allclose(loss, loss_NEST_reference, rtol=1e-8) + assert 
np.allclose(loss, loss_TF_reference, rtol=1e-6) diff --git a/testsuite/pytests/test_labeled_synapses.py b/testsuite/pytests/test_labeled_synapses.py index 998a84bc71..77e4e2bfd0 100644 --- a/testsuite/pytests/test_labeled_synapses.py +++ b/testsuite/pytests/test_labeled_synapses.py @@ -37,6 +37,7 @@ class LabeledSynapsesTestCase(unittest.TestCase): def default_network(self, syn_model): nest.ResetKernel() + # set volume transmitter for stdp_dopamine_synapse_lbl vt = nest.Create("volume_transmitter", 3) nest.SetDefaults("stdp_dopamine_synapse", {"volume_transmitter": vt[0]}) @@ -56,6 +57,13 @@ def default_network(self, syn_model): self.urbanczik_synapses = ["urbanczik_synapse", "urbanczik_synapse_lbl", "urbanczik_synapse_hpc"] + self.eprop_synapses = ["eprop_synapse_bsshslm_2020", "eprop_synapse_bsshslm_2020_hpc"] + self.eprop_connections = [ + "eprop_learning_signal_connection_bsshslm_2020", + "eprop_learning_signal_connection_bsshslm_2020_lbl", + "eprop_learning_signal_connection_bsshslm_2020_hpc", + ] + # create neurons that accept all synapse connections (especially gap # junctions)... hh_psc_alpha_gap is only available with GSL, hence the # skipIf above @@ -80,6 +88,12 @@ def default_network(self, syn_model): syns = nest.GetDefaults("pp_cond_exp_mc_urbanczik")["receptor_types"] r_type = syns["soma_exc"] + if syn_model in self.eprop_synapses: + neurons = nest.Create("eprop_iaf_bsshslm_2020", 5) + + if syn_model in self.eprop_connections: + neurons = nest.Create("eprop_readout_bsshslm_2020", 5) + nest.Create("eprop_iaf_bsshslm_2020", 5) + return neurons, r_type def test_SetLabelToSynapseOnConnect(self): @@ -182,16 +196,37 @@ def test_SetLabelToNotLabeledSynapse(self): with self.assertRaises(nest.kernel.NESTError): nest.SetDefaults(syn, {"synapse_label": 123}) - # try set on connect - with self.assertRaises(nest.kernel.NESTError): + # plain connection + if syn in self.eprop_connections or syn in self.eprop_synapses: + # try set on connect + with self.assertRaises(nest.kernel.NESTError): + nest.Connect( + a[:2], + a[-2:], + {"rule": "one_to_one", "make_symmetric": symm}, + {"synapse_model": syn, "synapse_label": 123, "delay": nest.resolution}, + ) nest.Connect( - a, a, {"rule": "one_to_one", "make_symmetric": symm}, {"synapse_model": syn, "synapse_label": 123} + a[:2], + a[-2:], + {"rule": "one_to_one", "make_symmetric": symm}, + {"synapse_model": syn, "receptor_type": r_type, "delay": nest.resolution}, + ) + else: + # try set on connect + with self.assertRaises(nest.kernel.NESTError): + nest.Connect( + a, + a, + {"rule": "one_to_one", "make_symmetric": symm}, + {"synapse_model": syn, "synapse_label": 123}, + ) + nest.Connect( + a, + a, + {"rule": "one_to_one", "make_symmetric": symm}, + {"synapse_model": syn, "receptor_type": r_type}, ) - - # plain connection - nest.Connect( - a, a, {"rule": "one_to_one", "make_symmetric": symm}, {"synapse_model": syn, "receptor_type": r_type} - ) # try set on SetStatus c = nest.GetConnections(a, a) diff --git a/testsuite/pytests/test_multimeter.py b/testsuite/pytests/test_multimeter.py index 9cbcacece2..6e9fc3073c 100644 --- a/testsuite/pytests/test_multimeter.py +++ b/testsuite/pytests/test_multimeter.py @@ -84,8 +84,6 @@ def test_recordables_are_recorded(model): This test does not check if the data is meaningful. 
""" - nest.resolution = 2**-3 # Set to power of two to avoid rounding issues - recording_interval = 2 simtime = 10 num_data_expected = simtime / recording_interval - 1 diff --git a/testsuite/pytests/test_refractory.py b/testsuite/pytests/test_refractory.py index 98cbe46ccd..3050d9950b 100644 --- a/testsuite/pytests/test_refractory.py +++ b/testsuite/pytests/test_refractory.py @@ -57,6 +57,11 @@ neurons_interspike_ps = ["iaf_psc_alpha_ps", "iaf_psc_delta_ps", "iaf_psc_exp_ps"] +neurons_eprop = [ + "eprop_iaf_bsshslm_2020", + "eprop_iaf_adapt_bsshslm_2020", +] + # Models that first clamp the membrane potential at a higher value neurons_with_clamping = [ "aeif_psc_delta_clopath", @@ -79,6 +84,7 @@ "iaf_psc_exp_ps_lossless", # This one use presice times "siegert_neuron", # This one does not connect to voltmeter "step_rate_generator", # No regular neuron model + "eprop_readout_bsshslm_2020", # This one does not spike "iaf_tum_2000", # Hijacks the offset field, see #2912 ] @@ -92,14 +98,6 @@ } -# --------------------------------------------------------------------------- # -# Simulation time and refractory time limits -# --------------------------------------------------------------------------- # - -simtime = 100 -resolution = 0.1 - - # --------------------------------------------------------------------------- # # Test class # --------------------------------------------------------------------------- # @@ -113,7 +111,7 @@ class TestRefractoryCase(unittest.TestCase): def reset(self): nest.ResetKernel() - nest.resolution = resolution + nest.resolution = self.resolution nest.rng_seed = 123456 def compute_reftime(self, model, sr, vm, neuron): @@ -141,8 +139,8 @@ def compute_reftime(self, model, sr, vm, neuron): if model in neurons_interspike: # Spike emitted at next timestep so substract resolution - return spike_times[1] - spike_times[0] - resolution - elif model in neurons_interspike_ps: + return spike_times[1] - spike_times[0] - self.resolution + elif model in neurons_interspike_ps + neurons_eprop: return spike_times[1] - spike_times[0] else: Vr = nest.GetStatus(neuron, "V_reset")[0] @@ -158,7 +156,7 @@ def compute_reftime(self, model, sr, vm, neuron): # Find end of refractory period between 1st and 2nd spike idx_end = np.where(np.isclose(Vs[idx_spike:idx_max], Vr, 1e-6))[0][-1] - t_ref_sim = idx_end * resolution + t_ref_sim = idx_end * self.resolution return t_ref_sim @@ -167,20 +165,22 @@ def test_refractory_time(self): Check that refractory time implementation is correct. 
""" + simtime = 100 + for model in tested_models: + self.resolution = 0.1 + t_ref = 1.7 self.reset() if "t_ref" not in nest.GetDefaults(model): continue - # Randomly set a refractory period - t_ref = 1.7 # Create the neuron and devices nparams = {"t_ref": t_ref} neuron = nest.Create(model, params=nparams) name_Vm = "V_m.s" if model in mc_models else "V_m" - vm_params = {"interval": resolution, "record_from": [name_Vm]} + vm_params = {"interval": self.resolution, "record_from": [name_Vm]} vm = nest.Create("voltmeter", params=vm_params) sr = nest.Create("spike_recorder") cg = nest.Create("dc_generator", params={"amplitude": 1200.0}) @@ -188,7 +188,7 @@ def test_refractory_time(self): # For models that do not clamp V_m, use very large current to # trigger almost immediate spiking => t_ref almost equals # interspike - if model in neurons_interspike_ps: + if model in neurons_interspike_ps + neurons_eprop: nest.SetStatus(cg, "amplitude", 10000000.0) elif model == "ht_neuron": # ht_neuron use too long time with a very large amplitude @@ -210,7 +210,7 @@ def test_refractory_time(self): t_ref_sim = t_ref_sim - nest.GetStatus(neuron, "t_clamp")[0] # Approximate result for precise spikes (interpolation error) - if model in neurons_interspike_ps: + if model in neurons_interspike_ps + neurons_eprop: self.assertAlmostEqual( t_ref, t_ref_sim, diff --git a/testsuite/pytests/test_sp/test_disconnect.py b/testsuite/pytests/test_sp/test_disconnect.py index 4b0e57b8fe..40c22e19a8 100644 --- a/testsuite/pytests/test_sp/test_disconnect.py +++ b/testsuite/pytests/test_sp/test_disconnect.py @@ -69,12 +69,18 @@ def test_synapse_deletion_one_to_one_no_sp(self): for syn_model in nest.synapse_models: if syn_model not in self.exclude_synapse_model: nest.ResetKernel() - nest.resolution = 0.1 nest.total_num_virtual_procs = nest.num_processes - neurons = nest.Create("iaf_psc_alpha", 4) syn_dict = {"synapse_model": syn_model} + if "eprop_synapse_bsshslm_2020" in syn_model: + neurons = nest.Create("eprop_iaf_bsshslm_2020", 4) + syn_dict["delay"] = nest.resolution + elif "eprop_learning_signal_connection_bsshslm_2020" in syn_model: + neurons = nest.Create("eprop_readout_bsshslm_2020", 2) + nest.Create("eprop_iaf_bsshslm_2020", 2) + else: + neurons = nest.Create("iaf_psc_alpha", 4) + nest.Connect(neurons[0], neurons[2], "one_to_one", syn_dict) nest.Connect(neurons[1], neurons[3], "one_to_one", syn_dict) diff --git a/testsuite/pytests/test_sp/test_disconnect_multiple.py b/testsuite/pytests/test_sp/test_disconnect_multiple.py index 32d854ee60..e0529f1645 100644 --- a/testsuite/pytests/test_sp/test_disconnect_multiple.py +++ b/testsuite/pytests/test_sp/test_disconnect_multiple.py @@ -49,6 +49,11 @@ def setUp(self): "urbanczik_synapse", "urbanczik_synapse_lbl", "urbanczik_synapse_hpc", + "eprop_synapse_bsshslm_2020", + "eprop_synapse_bsshslm_2020_hpc", + "eprop_learning_signal_connection_bsshslm_2020", + "eprop_learning_signal_connection_bsshslm_2020_lbl", + "eprop_learning_signal_connection_bsshslm_2020_hpc", "sic_connection", ] diff --git a/testsuite/regressiontests/ticket-310.sli b/testsuite/regressiontests/ticket-310.sli index 23c45e5c95..424fc7f96d 100644 --- a/testsuite/regressiontests/ticket-310.sli +++ b/testsuite/regressiontests/ticket-310.sli @@ -27,7 +27,7 @@ * V_m to be set to >= V_th, and that they emit a spike with * time stamp == resolution in that case. 
* - * Hans Ekkehard Plesser, 2009-02-11 + * Hans Ekkehard Plesser, 2009-02-11 * */ @@ -37,10 +37,12 @@ % use power-of-two resolution to avoid roundof problems /res -3 dexp def -/* The following models will not be tested: - iaf_cxhk_2008 --- spikes only on explicit positive threshold crossings -*/ -/skip_list [ /iaf_chxk_2008 /correlospinmatrix_detector ] def +/skip_list [ /iaf_chxk_2008 % non-standard spiking conditions + /correlospinmatrix_detector % not a neuron + /eprop_iaf_bsshslm_2020 % no ArchivingNode, thus no t_spike + /eprop_iaf_adapt_bsshslm_2020 % no ArchivingNode, thus no t_spike + /eprop_readout_bsshslm_2020 % no ArchivingNode, thus no t_spike + ] def { GetKernelStatus /node_models get @@ -59,10 +61,10 @@ n << /V_m n /V_th get 15.0 add >> SetStatus res Simulate n /t_spike get res leq % works also for precise models - dup not { (FAILED: ) model cvs join == n ShowStatus } if + dup not { (FAILED: ) model cvs join == n ShowStatus } if } { true } - ifelse + ifelse } { true } ifelse @@ -72,7 +74,7 @@ % see if all entries are true true exch { and } Fold -} +} assert_or_die endusing diff --git a/testsuite/regressiontests/ticket-386.sli b/testsuite/regressiontests/ticket-386.sli index 1578c8fa14..b620443b72 100644 --- a/testsuite/regressiontests/ticket-386.sli +++ b/testsuite/regressiontests/ticket-386.sli @@ -81,4 +81,3 @@ def GetKernelStatus /node_models get { run_test } forall } pass_or_die - diff --git a/testsuite/regressiontests/ticket-421.sli b/testsuite/regressiontests/ticket-421.sli index a540e1d73b..88e3e452fe 100644 --- a/testsuite/regressiontests/ticket-421.sli +++ b/testsuite/regressiontests/ticket-421.sli @@ -30,7 +30,7 @@ Description: This test simulates all nodes providing V_m for a short while and checks that V_m remains constant. This is a minimal test against missing variable initializations, cf ticket #421. - + Remarks: - Passing this test does not mean that all variables are properly initialized. It may just catch some cases bad cases. - Simulator response to initialization errors is stochastic, so if variables are not initialized properly, this test may @@ -40,21 +40,22 @@ Remarks: The check if that model really initializes the membrane potential V_m to the steady-state value in absence of any input. If not, add the model to the exclude_models list below. 
-Author: Hans Ekkehard Plesser, 2010-05-05 +Author: Hans Ekkehard Plesser, 2010-05-05 */ (unittest) run /unittest using % models that should not be tested because they do not initialize V_m to -% steady state -/exclude_models [/aeif_cond_exp /aeif_cond_alpha /a2eif_cond_exp /a2eif_cond_exp_HW +% steady state or require special resolution +/exclude_models [/aeif_cond_exp /aeif_cond_alpha /a2eif_cond_exp /a2eif_cond_exp_HW /aeif_cond_alpha_multisynapse /aeif_psc_delta_clopath /aeif_cond_alpha_astro /aeif_psc_exp /aeif_psc_alpha /aeif_psc_delta /aeif_cond_beta_multisynapse /hh_cond_exp_traub /hh_cond_beta_gap_traub /hh_psc_alpha /hh_psc_alpha_clopath /hh_psc_alpha_gap /ht_neuron /ht_neuron_fs - /iaf_cond_exp_sfa_rr /izhikevich] def + /iaf_cond_exp_sfa_rr /izhikevich + /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020] def -% use power-of-two resolution to avoid roundof problems +% use power-of-two resolution to avoid round-off problems /res -3 dexp def M_WARNING setverbosity @@ -82,12 +83,12 @@ M_WARNING setverbosity % check membrane potential for equality n /V_m get Vm0 sub abs 1e-13 lt - dup - { (*** OK: ) model cvs join ( *** ) join == } - { (###### FAIL : ) model cvs join - ( Vm0 = ) join Vm0 cvs join ( Vm = ) join n /V_m get cvs join == + dup + { (*** OK: ) model cvs join ( *** ) join == } + { (###### FAIL : ) model cvs join + ( Vm0 = ) join Vm0 cvs join ( Vm = ) join n /V_m get cvs join == } ifelse - } + } { true } ifelse } diff --git a/testsuite/regressiontests/ticket-618.sli b/testsuite/regressiontests/ticket-618.sli index c4078dae28..dd4d84d51f 100644 --- a/testsuite/regressiontests/ticket-618.sli +++ b/testsuite/regressiontests/ticket-618.sli @@ -26,7 +26,7 @@ Name: testsuite::ticket-618 - catch nodes which require tau_mem != tau_syn Synopsis: (ticket-618) run -> NEST exits if test fails -Description: +Description: All neuron models using exact integration require that tau_mem != tau_syn. This test ensures that all pertaining models raise an exception if tau_mem == tau_syn. @@ -37,7 +37,7 @@ or has a non-nan V_m after 10ms simulation. This test should be updated when alternative implementations of exact integration for the degenerate case are in place. - + Author: Hans Ekkehard Plesser, 2012-12-11 */ @@ -46,6 +46,8 @@ Author: Hans Ekkehard Plesser, 2012-12-11 M_ERROR setverbosity +/excluded_models [ /eprop_iaf_bsshslm_2020 /eprop_iaf_adapt_bsshslm_2020 /eprop_readout_bsshslm_2020 ] def + { GetKernelStatus /node_models get { @@ -57,13 +59,14 @@ M_ERROR setverbosity /result true def % pass by default % skip models without V_m - modelprops /V_m known - { + excluded_models model MemberQ not + modelprops /V_m known and + { % build dict setting all tau_* props to 10. /propdict << >> def modelprops keys - { - /key Set + { + /key Set key cvs length 4 geq { key cvs 0 4 getinterval (tau_) eq @@ -75,13 +78,13 @@ M_ERROR setverbosity % skip models without tau_* propdict empty not exch ; - { + { % the next line shall provoke an error for some % models mark - { + { /n model propdict Create def - } + } stopped { % we got an error, need to clean up @@ -95,25 +98,25 @@ M_ERROR setverbosity } { pop % mark - + % no error, simulate and check membrane potential is not nan - 10. Simulate - /result n /V_m get cvs (nan) neq def + 10. 
Simulate + /result n /V_m get cvs (nan) neq def } ifelse % stopped - } + } if % propdict empty not } if % /V_m known - result % leave result on stack + result % leave result on stack dup not { model == } if - } + } Map true exch { and } Fold - + } assert_or_die endusing
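For orientation, the following minimal sketch condenses how the new e-prop kernel attributes, neuron models, and synapse models introduced in this changeset fit together. It is an illustrative summary of the network setup used in test_eprop_bsshslm_2020_plasticity.py above, not part of the patch; the parameter values are placeholders and the delays are set to the resolution as in the tests.

    import nest

    nest.ResetKernel()
    nest.set(
        resolution=1.0,
        eprop_learning_window=1000.0,        # length of the learning window in ms
        eprop_update_interval=1000.0,        # weights are updated after each interval
        eprop_reset_neurons_on_update=True,  # kernel attribute introduced in this changeset
    )

    # One recurrent e-prop neuron, one readout neuron, and a target-rate generator
    nrn_rec = nest.Create("eprop_iaf_bsshslm_2020")
    nrn_out = nest.Create("eprop_readout_bsshslm_2020")
    gen_target = nest.Create("step_rate_generator")

    # Plastic forward connection; the tests set the delay equal to the resolution
    nest.Connect(
        nrn_rec,
        nrn_out,
        "all_to_all",
        {"synapse_model": "eprop_synapse_bsshslm_2020", "delay": nest.resolution},
    )

    # Feedback connection carrying the learning signal back to the recurrent neuron
    nest.Connect(
        nrn_out,
        nrn_rec,
        "all_to_all",
        {"synapse_model": "eprop_learning_signal_connection_bsshslm_2020", "delay": nest.resolution},
    )

    # Target signal enters the readout neuron on receptor type 2, as in the tests above
    nest.Connect(
        gen_target,
        nrn_out,
        "one_to_one",
        {"synapse_model": "rate_connection_delayed", "delay": nest.resolution, "receptor_type": 2},
    )

    nest.Simulate(2000.0)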