diff --git a/nengo_spinnaker/builder/node.py b/nengo_spinnaker/builder/node.py index 0679aed..ea9fc43 100644 --- a/nengo_spinnaker/builder/node.py +++ b/nengo_spinnaker/builder/node.py @@ -144,7 +144,8 @@ def get_node_source(self, model, cn): # reference to the value source we created earlier. return spec(ObjectPort(self._f_of_t_nodes[cn.pre_obj], OutputPort.standard)) - elif type(cn.post_obj) is nengo.Node: + elif (type(cn.post_obj) is nengo.Node and + cn.post_obj not in self._passthrough_nodes): # If this connection goes from a Node to another Node (exactly, not # any subclasses) then we just add both nodes and the connection to # the host model. @@ -190,7 +191,8 @@ def get_node_sink(self, model, cn): # to the Filter operator we created earlier regardless. return spec(ObjectPort(self._passthrough_nodes[cn.post_obj], InputPort.standard)) - elif type(cn.pre_obj) is nengo.Node: + elif (type(cn.pre_obj) is nengo.Node and + cn.pre_obj not in self._passthrough_nodes): # If this connection goes from a Node to another Node (exactly, not # any subclasses) then we just add both nodes and the connection to # the host model. diff --git a/regression-tests/test_passnodes.py b/regression-tests/test_passnodes.py index 05d05c7..27b96e5 100644 --- a/regression-tests/test_passnodes.py +++ b/regression-tests/test_passnodes.py @@ -9,6 +9,15 @@ def test_probe_passnodes(): """Test that pass nodes are left on SpiNNaker and that they may be probed. """ + class ValueReceiver(object): + def __init__(self): + self.ts = list() + self.values = list() + + def __call__(self, t, x): + self.ts.append(t) + self.values.append(x[:]) + with nengo.Network("Test Network") as net: # Create an input Node which is a function of time only input_node = nengo.Node(lambda t: -0.33 if t < 1.0 else 0.10, @@ -22,6 +31,12 @@ def test_probe_passnodes(): transform=[[1.0], [0.0], [-1.0]]) p_ens = nengo.Probe(ens.output, synapse=0.05) + # Also add a node connected to the end of the ensemble array to ensure + # that multiple things correctly receive values from the filter. + receiver = ValueReceiver() + n_receiver = nengo.Node(receiver, size_in=3) + nengo.Connection(ens.output, n_receiver, synapse=0.05) + # Mark the input Node as being a function of time nengo_spinnaker.add_spinnaker_params(net.config) net.config[input_node].function_of_time = True @@ -50,6 +65,9 @@ def test_probe_passnodes(): np.all(-0.05 >= data[index20:, 2]) and np.all(-0.15 <= data[index20:, 2])) + # Check that values came into the node correctly + assert +0.05 <= receiver.values[-1][0] <= +0.15 + assert -0.05 >= receiver.values[-1][2] >= -0.15 if __name__ == "__main__": test_probe_passnodes() diff --git a/tests/builder/test_node.py b/tests/builder/test_node.py index 3514fb5..88bd525 100644 --- a/tests/builder/test_node.py +++ b/tests/builder/test_node.py @@ -365,12 +365,22 @@ def test_passthrough_nodes_with_other_nodes(self): assert spec.target.obj is model.object_operators[b] assert spec.target.port is OutputPort.standard + # Get the sink and ensure that the appropriate object is returned + with mock.patch.object(nioc, "get_spinnaker_sink_for_node"): + assert nioc.get_node_sink(model, b_c) is not None + assert c in nioc._input_nodes + # Get the sink and ensure that the appropriate object is returned with mock.patch.object(nioc, "get_spinnaker_sink_for_node") as gssfn: spec = nioc.get_node_sink(model, a_b) assert spec.target.obj is model.object_operators[b] assert spec.target.port is InputPort.standard + # Get the source and ensure that the appropriate object is returned + with mock.patch.object(nioc, "get_spinnaker_source_for_node") as gssfn: + assert nioc.get_node_source(model, a_b) is not None + assert a in nioc._output_nodes + def test_get_node_sink_standard(self): """Test that calling a NodeIOController to get the sink for a connection which terminates at a Node calls the method