From df6ac904c70ba1ac2d118537ff2bd40394602a87 Mon Sep 17 00:00:00 2001 From: Daniel Jones Date: Tue, 13 Feb 2024 22:02:27 +0000 Subject: [PATCH] Add NearestNeighbour unit test --- tests/test_kdtree.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_kdtree.py b/tests/test_kdtree.py index 2f625002..db00d139 100644 --- a/tests/test_kdtree.py +++ b/tests/test_kdtree.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from signalflow import KDTree +from signalflow import KDTree, NearestNeighbour, Buffer +from . import graph -def test_kdtree(): +def test_kdtree(graph): corpus = np.random.uniform(0, 10, [1024, 2]) tree = KDTree(corpus) target = np.array([1, 5]) @@ -23,6 +24,13 @@ def test_kdtree(): # (subject to rounding error) assert nearest.distance == pytest.approx(np.linalg.norm(corpus[nearest.index] - target), abs=1e-6) + # confirm that the NearestNeighbour node returns the same value, + # constructing a buffer storing the corpus coordinates + buffer = Buffer(corpus.T) + nearest_neighbour = NearestNeighbour(buffer, target=target) + graph.render_subgraph(nearest_neighbour) + assert nearest_neighbour.output_buffer[0][0] == nearest.index + def test_kdtree_validation(): corpus = np.random.uniform(0, 10, [1024, 2]) tree = KDTree(corpus)