From 05a336ea69efb5e8c9f99d0424811154834ec665 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 15 Nov 2023 12:56:15 -0500 Subject: [PATCH] Utilize exact kNN search when gathering k >= numVectors in a segment (#12806) When requesting k >= numVectors, it doesn't make sense to go through the HNSW graph. Even without a user-supplied filter, we should not explore the HNSW graph if it contains k or fewer vectors. One scenario where we may still explore the graph if k >= numVectors is when not every document has a vector and there are deleted docs. But, this commit significantly improves things regardless. --- lucene/CHANGES.txt | 2 + .../lucene99/Lucene99HnswVectorsReader.java | 44 ++++++++++++++----- .../search/BaseKnnVectorQueryTestCase.java | 14 +++++- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c3a62910e9b8..4bc30f2058bd 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -285,6 +285,8 @@ Optimizations * GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. 
(Guo Feng) +* GITHUB#12806: Utilize exact kNN search when gathering k >= numVectors in a segment (Ben Trent) + Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index fb9a4bb550fe..140477cf749f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -227,12 +227,22 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { return; } - RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); - HnswGraphSearcher.search( - scorer, - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc), - getGraph(fieldEntry), - scorer.getAcceptOrds(acceptDocs)); + final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); + final KnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + if (knnCollector.k() < scorer.maxOrd()) { + HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds); + } else { + // if k is larger than the number of vectors, we can just iterate over all vectors + // and collect them + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + knnCollector.incVisitedCount(1); + knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); + } + } + } } @Override @@ -245,12 +255,22 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { return; } - RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); - HnswGraphSearcher.search( - 
scorer, - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc), - getGraph(fieldEntry), - scorer.getAcceptOrds(acceptDocs)); + final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); + final KnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + if (knnCollector.k() < scorer.maxOrd()) { + HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds); + } else { + // if k is larger than the number of vectors, we can just iterate over all vectors + // and collect them + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + knnCollector.incVisitedCount(1); + knnCollector.collect(scorer.ordToDoc(i), scorer.score(i)); + } + } + } } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index fb4775712a03..6e3f90fff595 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -17,6 +17,8 @@ package org.apache.lucene.search; import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -222,7 +224,7 @@ public void testDimensionMismatch() throws IOException { getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); IndexReader reader = DirectoryReader.open(indexStore)) { IndexSearcher searcher = newSearcher(reader); - AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new 
float[] {0}, 10); + AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 1); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage()); @@ -779,6 +781,16 @@ Directory getIndexStore( doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction)); doc.add(new StringField("id", "id" + i, Field.Store.YES)); writer.addDocument(doc); + if (randomBoolean()) { + // Add some documents without a vector + for (int j = 0; j < randomIntBetween(1, 5); j++) { + doc = new Document(); + doc.add(new StringField("other", "value", Field.Store.NO)); + // Add fields that will be matched by our test filters but won't have vectors + doc.add(new StringField("id", "id" + j, Field.Store.YES)); + writer.addDocument(doc); + } + } } // Add some documents without a vector for (int i = 0; i < 5; i++) {