Merge pull request #3160 from xavierotazuGDS/master
Improve speed of dump_connections() in layer_impl.h
heplesser authored Jun 24, 2024
2 parents 656e13d + dedffa4 commit 87ec23a
Showing 3 changed files with 47 additions and 41 deletions.
77 changes: 38 additions & 39 deletions nestkernel/layer_impl.h
@@ -310,53 +310,52 @@ Layer< D >::dump_connections( std::ostream& out,
   AbstractLayerPTR target_layer,
   const Token& syn_model )
 {
-  std::vector< std::pair< Position< D >, size_t > >* src_vec = get_global_positions_vector( node_collection );
-
-  // Dictionary with parameters for get_connections()
+  // Find all connections for given sources, targets and synapse model
   DictionaryDatum conn_filter( new Dictionary );
-  def( conn_filter, names::synapse_model, syn_model );
+  def( conn_filter, names::source, NodeCollectionDatum( node_collection ) );
   def( conn_filter, names::target, NodeCollectionDatum( target_layer->get_node_collection() ) );
+  def( conn_filter, names::synapse_model, syn_model );
+  ArrayDatum connectome = kernel().connection_manager.get_connections( conn_filter );
 
-  // Avoid setting up new array for each iteration of the loop
-  std::vector< size_t > source_array( 1 );
+  // Get positions of remote nodes
+  std::vector< std::pair< Position< D >, size_t > >* src_vec = get_global_positions_vector( node_collection );
 
-  for ( typename std::vector< std::pair< Position< D >, size_t > >::iterator src_iter = src_vec->begin();
-        src_iter != src_vec->end();
-        ++src_iter )
+  // Iterate over connectome and write every connection, looking up source position only if source neuron changes
+  size_t previous_source_node_id = 0; // dummy initial value, cannot be node_id of any node
+  Position< D > source_pos;           // dummy value
+  for ( const auto& entry : connectome )
   {
+    ConnectionDatum conn = getValue< ConnectionDatum >( entry );
+    const size_t source_node_id = conn.get_source_node_id();
 
-    const size_t source_node_id = src_iter->second;
-    const Position< D > source_pos = src_iter->first;
-
-    source_array[ 0 ] = source_node_id;
-    def( conn_filter, names::source, NodeCollectionDatum( NodeCollection::create( source_array ) ) );
-    ArrayDatum connectome = kernel().connection_manager.get_connections( conn_filter );
-
-    // Print information about all local connections for current source
-    for ( size_t i = 0; i < connectome.size(); ++i )
+    // Search source_pos for source node only if it is a different node
+    if ( source_node_id != previous_source_node_id )
     {
-      ConnectionDatum con_id = getValue< ConnectionDatum >( connectome.get( i ) );
-      DictionaryDatum result_dict = kernel().connection_manager.get_synapse_status( con_id.get_source_node_id(),
-        con_id.get_target_node_id(),
-        con_id.get_target_thread(),
-        con_id.get_synapse_model_id(),
-        con_id.get_port() );
-
-      long target_node_id = getValue< long >( result_dict, names::target );
-      double weight = getValue< double >( result_dict, names::weight );
-      double delay = getValue< double >( result_dict, names::delay );
-
-      // Print source, target, weight, delay, rports
-      out << source_node_id << ' ' << target_node_id << ' ' << weight << ' ' << delay;
-
-      Layer< D >* tgt_layer = dynamic_cast< Layer< D >* >( target_layer.get() );
-
-      out << ' ';
-      const long tnode_lid = tgt_layer->node_collection_->get_nc_index( target_node_id );
-      assert( tnode_lid >= 0 );
-      tgt_layer->compute_displacement( source_pos, tnode_lid ).print( out );
-      out << '\n';
+      const auto it = std::find_if( src_vec->begin(),
+        src_vec->end(),
+        [ source_node_id ]( const std::pair< Position< D >, size_t >& p ) { return p.second == source_node_id; } );
+      assert( it != src_vec->end() ); // internal error if node not found
+
+      source_pos = it->first;
+      previous_source_node_id = source_node_id;
     }
+
+    DictionaryDatum result_dict = kernel().connection_manager.get_synapse_status( source_node_id,
+      conn.get_target_node_id(),
+      conn.get_target_thread(),
+      conn.get_synapse_model_id(),
+      conn.get_port() );
+    const long target_node_id = getValue< long >( result_dict, names::target );
+    const double weight = getValue< double >( result_dict, names::weight );
+    const double delay = getValue< double >( result_dict, names::delay );
+    const Layer< D >* const tgt_layer = dynamic_cast< Layer< D >* >( target_layer.get() );
+    const long tnode_lid = tgt_layer->node_collection_->get_nc_index( target_node_id );
+    assert( tnode_lid >= 0 );
+
+    // Print source, target, weight, delay, rports
+    out << source_node_id << ' ' << target_node_id << ' ' << weight << ' ' << delay << ' ';
+    tgt_layer->compute_displacement( source_pos, tnode_lid ).print( out );
+    out << '\n';
   }
 }
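
The speed-up comes from replacing one get_connections() query per source neuron with a single query over the whole connectome, then looking up a source position only when the source neuron changes. A minimal, NEST-independent Python sketch of that caching pattern (all names below are illustrative, not NEST API):

    # Illustrative only: stream over all connections once and cache the position
    # lookup for the current source instead of issuing one query per source node.
    def dump_connections(connections, positions, out):
        # connections: (source_id, target_id) pairs, grouped by source
        # positions: list of (position, node_id) pairs from a global-positions query
        previous_source = None
        source_pos = None
        for source_id, target_id in connections:
            if source_id != previous_source:  # look up only when the source changes
                source_pos = next(pos for pos, node_id in positions if node_id == source_id)
                previous_source = source_id
            out.write(f"{source_id} {target_id} {source_pos}\n")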

@@ -139,7 +139,13 @@ def compare_layers_and_connections(use_free_mask, tmp_path, mask_params, edge_wr
 
     assert np.all(stored_src == src_layer_ref)
     assert np.all(stored_target == target_layer_ref)
-    assert np.all(stored_connections == connections_ref)
+
+    # The order in which connections are written to file is implementation dependent. Therefore, we need to
+    # sort results and expectations here. We use lexsort to sort entire rows of the arrays. We need to
+    # transpose the array to provide it properly as keys to lexsort.
+    np.testing.assert_equal(
+        stored_connections[np.lexsort(stored_connections.T), :], connections_ref[np.lexsort(connections_ref.T), :]
+    )
 
 
 @pytest.mark.parametrize(
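
The comment in the hunk above relies on numpy's lexsort: it takes a sequence of key arrays and sorts primarily by the last key, so passing the transposed array turns each column into a key and yields a fixed row order on both sides of the comparison. A small self-contained sketch of the idiom, with made-up data:

    import numpy as np

    a = np.array([[3, 1], [1, 2], [1, 1]])
    b = a[[2, 0, 1]]  # same rows, different order

    # np.lexsort(a.T) returns row indices sorted with the last column as primary key,
    # giving a deterministic row order regardless of the original ordering.
    a_sorted = a[np.lexsort(a.T), :]
    b_sorted = b[np.lexsort(b.T), :]
    np.testing.assert_equal(a_sorted, b_sorted)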
3 changes: 2 additions & 1 deletion testsuite/pytests/test_spatial/test_weight_delay.py
@@ -118,8 +118,9 @@ def test_layer_connections_dump(tmp_path, expected_conn_dump, layer_type):
     fname = tmp_path / f"{layer_type}_layer_conns.txt"
     nest.DumpLayerConnections(src_layer, tgt_layer, "static_synapse", fname)
 
+    # We need to sort results to be invariant against implementation-dependent output order
     actual_conn_dump = fname.read_text().splitlines()
-    assert actual_conn_dump == expected_conn_dump
+    assert sorted(actual_conn_dump) == sorted(expected_conn_dump)
 
 
 @pytest.fixture(scope="module")
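
The sorted() comparison matters because list.sort() sorts in place and returns None, so comparing the return values of .sort() would only compare None with None and always pass; sorted() returns new lists that can be compared directly. A short illustration with made-up dump lines:

    lines_a = ["b 1.0 1.5", "a 2.0 1.5"]
    lines_b = ["a 2.0 1.5", "b 1.0 1.5"]

    assert lines_a.sort() is None               # in-place sort, returns None
    assert sorted(lines_a) == sorted(lines_b)   # order-insensitive comparison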
