diff --git a/tests/test_kdtree.py b/tests/test_kdtree.py index 2f625002..db00d139 100644 --- a/tests/test_kdtree.py +++ b/tests/test_kdtree.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from signalflow import KDTree +from signalflow import KDTree, NearestNeighbour, Buffer +from . import graph -def test_kdtree(): +def test_kdtree(graph): corpus = np.random.uniform(0, 10, [1024, 2]) tree = KDTree(corpus) target = np.array([1, 5]) @@ -23,6 +24,13 @@ def test_kdtree(): # (subject to rounding error) assert nearest.distance == pytest.approx(np.linalg.norm(corpus[nearest.index] - target), abs=1e-6) + # confirm that the NearestNeighbour node returns the same value, + # constructing a buffer storing the corpus coordinates + buffer = Buffer(corpus.T) + nearest_neighbour = NearestNeighbour(buffer, target=target) + graph.render_subgraph(nearest_neighbour) + assert nearest_neighbour.output_buffer[0][0] == nearest.index + def test_kdtree_validation(): corpus = np.random.uniform(0, 10, [1024, 2]) tree = KDTree(corpus)