diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 74d81c5f2dea..c65b584beb19 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -155,6 +155,8 @@ Bug Fixes * GITHUB#12388: JoinUtil queries were ignoring boosts. (Alan Woodward) +* GITHUB#12413: Fix HNSW graph search bug that potentially leaked unapproved docs. (Ben Trent) + Other --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 1cd5183a993b..ab792ca3bd05 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -204,26 +204,26 @@ private static NeighborQueue search( if (initialEp == -1) { return new NeighborQueue(1, true); } - NeighborQueue results; - results = new NeighborQueue(1, false); - int[] eps = new int[] {graph.entryNode()}; - int numVisited = 0; - for (int level = graph.numLevels() - 1; level >= 1; level--) { - results.clear(); - graphSearcher.searchLevel(results, query, 1, level, eps, vectors, graph, null, visitedLimit); - - numVisited += results.visitedCount(); - visitedLimit -= results.visitedCount(); - - if (results.incomplete()) { - results.setVisitedCount(numVisited); - return results; - } - eps[0] = results.pop(); + int[] epAndVisited = graphSearcher.findBestEntryPoint(query, vectors, graph, visitedLimit); + int numVisited = epAndVisited[1]; + int ep = epAndVisited[0]; + if (ep == -1) { + NeighborQueue results = new NeighborQueue(1, false); + results.setVisitedCount(numVisited); + results.markIncomplete(); + return results; } - results = new NeighborQueue(topK, false); + NeighborQueue results = new NeighborQueue(topK, false); graphSearcher.searchLevel( - results, query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); + results, + query, + topK, + 0, + new int[] {ep}, + vectors, + graph, + acceptOrds, + visitedLimit - numVisited); 
results.setVisitedCount(results.visitedCount() + numVisited); return results; } @@ -256,6 +256,56 @@ public NeighborQueue searchLevel( return results; } + /** + * Finds the best entry point from which to search the zeroth graph layer. + * + * @param query vector query with which to search + * @param vectors random access vector values + * @param graph the HNSWGraph + * @param visitLimit How many vectors are allowed to be visited + * @return An integer array whose first element is the best entry point, and second is the number + * of candidates visited. An entry point of {@code -1} indicates the visit limit was reached + * @throws IOException When accessing the vector fails + */ + private int[] findBestEntryPoint( + T query, RandomAccessVectorValues vectors, HnswGraph graph, int visitLimit) + throws IOException { + int size = graph.size(); + int visitedCount = 1; + prepareScratchState(vectors.size()); + int currentEp = graph.entryNode(); + float currentScore = compare(query, vectors, currentEp); + boolean foundBetter; + for (int level = graph.numLevels() - 1; level >= 1; level--) { + foundBetter = true; + visited.set(currentEp); + // Keep searching the given level until we stop finding a better candidate entry point + while (foundBetter) { + foundBetter = false; + graphSeek(graph, level, currentEp); + int friendOrd; + while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) { + assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; + if (visited.getAndSet(friendOrd)) { + continue; + } + if (visitedCount >= visitLimit) { + return new int[] {-1, visitedCount}; + } + float friendSimilarity = compare(query, vectors, friendOrd); + visitedCount++; + if (friendSimilarity > currentScore + || (friendSimilarity == currentScore && friendOrd < currentEp)) { + currentScore = friendSimilarity; + currentEp = friendOrd; + foundBetter = true; + } + } + } + } + return new int[] {currentEp, visitedCount}; + } + /** * Add the closest neighbors found to a priority 
queue (heap). These are returned in REVERSE * proximity order -- the most distant neighbor of the topK found, i.e. the one with the lowest