Skip to content

Commit

Permalink
Utilize exact kNN search when gathering k > numVectors in a segment (a…
Browse files Browse the repository at this point in the history
…pache#12806)

When requesting for k >= numVectors, it doesn't make sense to go through the HNSW graph. Even without a user supplied filter, we should not explore the HNSW graph if it contains fewer than k vectors.

One scenario where we may still explore the graph if k >= numVectors is when not every document has a vector and there are deleted docs. But, this commit significantly improves things regardless.
  • Loading branch information
benwtrent authored Nov 15, 2023
1 parent 5afc17d commit 05a336e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ Optimizations

* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng)

* GITHUB#12806: Utilize exact kNN search when gathering k >= numVectors in a segment (Ben Trent)

Changes in runtime behavior
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,22 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry),
scorer.getAcceptOrds(acceptDocs));
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
}

@Override
Expand All @@ -245,12 +255,22 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return;
}
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
getGraph(fieldEntry),
scorer.getAcceptOrds(acceptDocs));
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.lucene.search;

import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

Expand Down Expand Up @@ -222,7 +224,7 @@ public void testDimensionMismatch() throws IOException {
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 1);
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
Expand Down Expand Up @@ -779,6 +781,16 @@ Directory getIndexStore(
doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction));
doc.add(new StringField("id", "id" + i, Field.Store.YES));
writer.addDocument(doc);
if (randomBoolean()) {
// Add some documents without a vector
for (int j = 0; j < randomIntBetween(1, 5); j++) {
doc = new Document();
doc.add(new StringField("other", "value", Field.Store.NO));
// Add fields that will be matched by our test filters but won't have vectors
doc.add(new StringField("id", "id" + j, Field.Store.YES));
writer.addDocument(doc);
}
}
}
// Add some documents without a vector
for (int i = 0; i < 5; i++) {
Expand Down

0 comments on commit 05a336e

Please sign in to comment.