diff --git a/nengo_spinnaker/builder/connection.py b/nengo_spinnaker/builder/connection.py index c132ce2..92bc4e5 100644 --- a/nengo_spinnaker/builder/connection.py +++ b/nengo_spinnaker/builder/connection.py @@ -20,10 +20,9 @@ def generic_sink_getter(model, conn): @Model.connection_parameter_builders.register(nengo.base.NengoObject) def build_generic_connection_params(model, conn): - transform = full_transform(conn) return BuiltConnection( decoders=None, - transform=transform, + transform=full_transform(conn, slice_pre=False), eval_points=None, solver_info=None ) diff --git a/nengo_spinnaker/node_io/ethernet.py b/nengo_spinnaker/node_io/ethernet.py index 00dd654..ed5a63e 100644 --- a/nengo_spinnaker/node_io/ethernet.py +++ b/nengo_spinnaker/node_io/ethernet.py @@ -111,8 +111,8 @@ def set_node_output(self, node, value): # Build an SDP packet to transmit for each outgoing connection for the # node for connection, (x, y, p) in self._node_outgoing[node]: - # Perform connection function and transform - c_value = value[:] + # Apply the pre-slice, the connection function and the transform. + c_value = value[connection.pre_slice] if connection.function is not None: c_value = connection.function(c_value) c_value = np.dot(connection.transform, c_value) diff --git a/nengo_spinnaker/operators/value_source.py b/nengo_spinnaker/operators/value_source.py index 964c8f2..32fff75 100644 --- a/nengo_spinnaker/operators/value_source.py +++ b/nengo_spinnaker/operators/value_source.py @@ -6,6 +6,7 @@ import struct from nengo.processes import Process +from nengo.utils import numpy as npext from nengo_spinnaker.builder.builder import OutputPort, netlistspec from nengo_spinnaker.netlist import VertexSlice @@ -50,9 +51,9 @@ def make_vertices(self, model, n_steps): # Add the keys for this connection conn = conns[0] - so = conns[0].size_out keys.extend(list( - get_derived_keyspaces(sig.keyspace, slice(0, so)) + get_derived_keyspaces(sig.keyspace, conn.post_slice, + max_v=conn.post_obj.size_in) )) self.conns.append(conn) size_out = len(keys) @@ -140,16 +141,31 @@ def before_simulation(self, netlist, simulator, n_steps): else: values = np.array([self.function for t in ts]) + # Ensure that the values can be sliced, regardless of how they were + # generated. + values = npext.array(values, min_dims=2) + # Compute the output for each connection outputs = [] for conn in self.conns: output = [] + + # For each f(t) for the next set of simulations we calculate the + # output at the end of the connection. To do this we first apply + # the pre-slice, then the function and then the post-slice. for v in values: + # Apply the pre-slice + v = v[conn.pre_slice] + + # Apply the function on the connection, if there is one. if conn.function is not None: v = conn.function(v) + output.append(np.dot(conn.transform, v.T)) outputs.append(np.array(output).reshape(n_steps, conn.size_out)) + # Combine all of the output values to form a large matrix which we can + # dump into memory. output_matrix = np.hstack(outputs) new_output_region = regions.MatrixRegion( diff --git a/regression-tests/test_nodes_sliced.py b/regression-tests/test_nodes_sliced.py new file mode 100644 index 0000000..d9f5560 --- /dev/null +++ b/regression-tests/test_nodes_sliced.py @@ -0,0 +1,58 @@ +"""More complex function of time Node example. +""" +import nengo +import nengo_spinnaker +import numpy as np +import pytest + + +@pytest.mark.parametrize("f_of_t", [True, False]) +def test_nodes_sliced(f_of_t): + # Create a model with a single function of time node which returns a 4D + # vector, apply preslicing on some connections from it and ensure that this + # slicing plays nicely with the functions attached to the connections. + def out_fun_1(val): + assert val.size == 2 + return val * 2 + + with nengo.Network() as model: + # Create the input node and an ensemble + in_node = nengo.Node(lambda t: [0.1, 1.0, 0.2, -1.0], size_out=4) + in_node_2 = nengo.Node(0.25) + + ens = nengo.Ensemble(400, 4) + ens2 = nengo.Ensemble(200, 2) + + # Create the connections + nengo.Connection(in_node[::2], ens[[1, 3]], transform=.5, + function=out_fun_1) + nengo.Connection(in_node_2[[0, 0]], ens2) + + # Probe the ensemble to ensure that the values are correct + p = nengo.Probe(ens, synapse=0.05) + p2 = nengo.Probe(ens2, synapse=0.05) + + # Mark the input as being a function of time if desired + if f_of_t: + nengo_spinnaker.add_spinnaker_params(model.config) + model.config[in_node].function_of_time = True + + # Run the simulator for 1.0 s and check that the last probed values are in + # range + sim = nengo_spinnaker.Simulator(model) + with sim: + sim.run(1.0) + + # Check the final values + assert -0.05 < sim.data[p][-1, 0] < 0.05 + assert 0.05 < sim.data[p][-1, 1] < 0.15 + assert -0.05 < sim.data[p][-1, 2] < 0.05 + assert 0.15 < sim.data[p][-1, 3] < 0.25 + + assert 0.20 < sim.data[p2][-1, 0] < 0.30 + assert 0.20 < sim.data[p2][-1, 1] < 0.30 + + +if __name__ == "__main__": + test_nodes_sliced(True) + test_nodes_sliced(False) diff --git a/tests/builder/test_connection.py b/tests/builder/test_connection.py index 8903233..6bf649d 100644 --- a/tests/builder/test_connection.py +++ b/tests/builder/test_connection.py @@ -66,6 +66,6 @@ def test_build_standard_connection_params(): # Build the connection parameters params = build_generic_connection_params(None, a_b) assert params.decoders is None - assert np.all(params.transform == [[1.0, 0.0]]) + assert params.transform == 1.0 assert params.eval_points is None assert params.solver_info is None