Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BooleanScorer to work on top of Scorers rather than BulkScorers. #13931

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Optimizations

* GITHUB#13930: Use growNoCopy when copying bytes in BytesRefBuilder. (Ignacio Vera)

* GITHUB#13931: Refactored `BooleanScorer` to evaluate matches of sub clauses
using the `Scorer` abstraction rather than the `BulkScorer` abstraction. This
speeds up exhaustive evaluation of disjunctions of term queries.
(Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
191 changes: 87 additions & 104 deletions lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.PriorityQueue;

/**
* {@link BulkScorer} that is used for pure disjunctions and disjunctions that have low values of
* {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int)} and dense clauses. This scorer
* scores documents by batches of 2048 docs.
* scores documents by batches of 4,096 docs.
*/
final class BooleanScorer extends BulkScorer {

Expand All @@ -41,71 +42,32 @@ static class Bucket {
int freq;
}

private class BulkScorerAndDoc {
final BulkScorer scorer;
final long cost;
int next;

BulkScorerAndDoc(BulkScorer scorer) {
this.scorer = scorer;
this.cost = scorer.cost();
this.next = -1;
}

void advance(int min) throws IOException {
score(orCollector, null, min, min);
}

void score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
next = scorer.score(collector, acceptDocs, min, max);
}
}

// See WANDScorer for an explanation
private static long cost(Collection<BulkScorer> scorers, int minShouldMatch) {
final PriorityQueue<BulkScorer> pq =
new PriorityQueue<BulkScorer>(scorers.size() - minShouldMatch + 1) {
@Override
protected boolean lessThan(BulkScorer a, BulkScorer b) {
return a.cost() > b.cost();
}
};
for (BulkScorer scorer : scorers) {
pq.insertWithOverflow(scorer);
}
long cost = 0;
for (BulkScorer scorer = pq.pop(); scorer != null; scorer = pq.pop()) {
cost += scorer.cost();
}
return cost;
}

static final class HeadPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class HeadPriorityQueue extends PriorityQueue<DisiWrapper> {

public HeadPriorityQueue(int maxSize) {
super(maxSize);
}

@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
return a.next < b.next;
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.doc < b.doc;
}
}

static final class TailPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class TailPriorityQueue extends PriorityQueue<DisiWrapper> {

public TailPriorityQueue(int maxSize) {
super(maxSize);
}

@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.cost < b.cost;
}

public BulkScorerAndDoc get(int i) {
public DisiWrapper get(int i) {
Objects.checkIndex(i, size());
return (BulkScorerAndDoc) getHeapArray()[1 + i];
return (DisiWrapper) getHeapArray()[1 + i];
}
}

Expand All @@ -115,39 +77,14 @@ public BulkScorerAndDoc get(int i) {
// This is basically an inlined FixedBitSet... seems to help with bound checks
final long[] matching = new long[SET_SIZE];

final BulkScorerAndDoc[] leads;
final DisiWrapper[] leads;
final HeadPriorityQueue head;
final TailPriorityQueue tail;
final Score score = new Score();
final int minShouldMatch;
final long cost;
final boolean needsScores;

final class OrCollector implements LeafCollector {
Scorable scorer;

@Override
public void setScorer(Scorable scorer) {
this.scorer = scorer;
}

@Override
public void collect(int doc) throws IOException {
final int i = doc & MASK;
final int idx = i >>> 6;
matching[idx] |= 1L << i;
if (buckets != null) {
final Bucket bucket = buckets[i];
bucket.freq++;
if (needsScores) {
bucket.score += scorer.score();
}
}
}
}

final OrCollector orCollector = new OrCollector();

final class DocIdStreamView extends DocIdStream {

int base;
Expand Down Expand Up @@ -194,7 +131,7 @@ public int count() throws IOException {

private final DocIdStreamView docIdStreamView = new DocIdStreamView();

BooleanScorer(Collection<BulkScorer> scorers, int minShouldMatch, boolean needsScores) {
BooleanScorer(Collection<Scorer> scorers, int minShouldMatch, boolean needsScores) {
if (minShouldMatch < 1 || minShouldMatch > scorers.size()) {
throw new IllegalArgumentException(
"minShouldMatch should be within 1..num_scorers. Got " + minShouldMatch);
Expand All @@ -211,38 +148,71 @@ public int count() throws IOException {
} else {
buckets = null;
}
this.leads = new BulkScorerAndDoc[scorers.size()];
this.leads = new DisiWrapper[scorers.size()];
this.head = new HeadPriorityQueue(scorers.size() - minShouldMatch + 1);
this.tail = new TailPriorityQueue(minShouldMatch - 1);
this.minShouldMatch = minShouldMatch;
this.needsScores = needsScores;
for (BulkScorer scorer : scorers) {
final BulkScorerAndDoc evicted = tail.insertWithOverflow(new BulkScorerAndDoc(scorer));
LongArrayList costs = new LongArrayList(scorers.size());
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
costs.add(w.cost);
final DisiWrapper evicted = tail.insertWithOverflow(w);
if (evicted != null) {
head.add(evicted);
}
}
this.cost = cost(scorers, minShouldMatch);
this.cost = ScorerUtil.costWithMinShouldMatch(costs.stream(), costs.size(), minShouldMatch);
}

/** Returns the cost estimate that was computed once, in the constructor, from the sub-scorers. */
@Override
public long cost() {
  return this.cost;
}

/**
 * Walks the matches of a single clause inside {@code [min, max)} and records every accepted
 * document in the window state: its bit is set in {@code matching} and, when buckets are in use,
 * its bucket's frequency (and, if scores are needed, its score) is accumulated.
 *
 * @param w the clause to consume; on return {@code w.doc} holds the first doc at or after
 *     {@code max} that its iterator stopped on
 * @param acceptDocs filter for documents that may be collected, or {@code null} to accept all
 * @param min lower bound (inclusive); the iterator is advanced here if it is still behind
 * @param max upper bound (exclusive) of the doc IDs recorded by this call
 * @throws IOException if the underlying iterator or scorer fails
 */
private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
    throws IOException {
  // Pull the shared state into locals once so the loop below only touches locals.
  final long[] windowBits = BooleanScorer.this.matching;
  final Bucket[] windowBuckets = BooleanScorer.this.buckets;
  final boolean accumulateScores = BooleanScorer.this.needsScores;

  final DocIdSetIterator iterator = w.iterator;
  final Scorer clauseScorer = w.scorer;

  // Only move the iterator when it has not yet reached the start of the requested range.
  int current = w.doc;
  if (current < min) {
    current = iterator.advance(min);
  }

  while (current < max) {
    if (acceptDocs == null || acceptDocs.get(current)) {
      // Slot of this doc inside the current window.
      final int slot = current & MASK;
      final int word = slot >> 6;
      // Java masks the shift distance of a long to its low 6 bits, so this picks the bit
      // within the selected word.
      windowBits[word] |= 1L << slot;
      if (windowBuckets != null) {
        final Bucket target = windowBuckets[slot];
        target.freq++;
        if (accumulateScores) {
          target.score += clauseScorer.score();
        }
      }
    }
    current = iterator.nextDoc();
  }

  // Remember where the iterator stopped so later calls can resume from there.
  w.doc = current;
}

private void scoreWindowIntoBitSetAndReplay(
LeafCollector collector,
Bits acceptDocs,
int base,
int min,
int max,
BulkScorerAndDoc[] scorers,
DisiWrapper[] scorers,
int numScorers)
throws IOException {
for (int i = 0; i < numScorers; ++i) {
final BulkScorerAndDoc scorer = scorers[i];
assert scorer.next < max;
scorer.score(orCollector, acceptDocs, min, max);
final DisiWrapper w = scorers[i];
assert w.doc < max;
scoreDisiWrapperIntoBitSet(w, acceptDocs, min, max);
}

docIdStreamView.base = base;
Expand All @@ -251,20 +221,20 @@ private void scoreWindowIntoBitSetAndReplay(
Arrays.fill(matching, 0L);
}

private BulkScorerAndDoc advance(int min) throws IOException {
private DisiWrapper advance(int min) throws IOException {
assert tail.size() == minShouldMatch - 1;
final HeadPriorityQueue head = this.head;
final TailPriorityQueue tail = this.tail;
BulkScorerAndDoc headTop = head.top();
BulkScorerAndDoc tailTop = tail.top();
while (headTop.next < min) {
DisiWrapper headTop = head.top();
DisiWrapper tailTop = tail.top();
while (headTop.doc < min) {
if (tailTop == null || headTop.cost <= tailTop.cost) {
headTop.advance(min);
headTop.doc = headTop.iterator.advance(min);
headTop = head.updateTop();
} else {
// swap the top of head and tail
final BulkScorerAndDoc previousHeadTop = headTop;
tailTop.advance(min);
final DisiWrapper previousHeadTop = headTop;
tailTop.doc = tailTop.iterator.advance(min);
headTop = head.updateTop(tailTop);
tailTop = tail.updateTop(previousHeadTop);
}
Expand All @@ -282,9 +252,11 @@ private void scoreWindowMultipleScorers(
throws IOException {
while (maxFreq < minShouldMatch && maxFreq + tail.size() >= minShouldMatch) {
// a match is still possible
final BulkScorerAndDoc candidate = tail.pop();
candidate.advance(windowMin);
if (candidate.next < windowMax) {
final DisiWrapper candidate = tail.pop();
if (candidate.doc < windowMin) {
candidate.doc = candidate.iterator.advance(windowMin);
}
if (candidate.doc < windowMax) {
leads[maxFreq++] = candidate;
} else {
head.add(candidate);
Expand All @@ -304,49 +276,60 @@ private void scoreWindowMultipleScorers(

// Push back scorers into head and tail
for (int i = 0; i < maxFreq; ++i) {
final BulkScorerAndDoc evicted = head.insertWithOverflow(leads[i]);
final DisiWrapper evicted = head.insertWithOverflow(leads[i]);
if (evicted != null) {
tail.add(evicted);
}
}
}

private void scoreWindowSingleScorer(
BulkScorerAndDoc bulkScorer,
DisiWrapper w,
LeafCollector collector,
Bits acceptDocs,
int windowMin,
int windowMax,
int max)
throws IOException {
assert tail.size() == 0;
final int nextWindowBase = head.top().next & ~MASK;
final int nextWindowBase = head.top().doc & ~MASK;
final int end = Math.max(windowMax, Math.min(max, nextWindowBase));

bulkScorer.score(collector, acceptDocs, windowMin, end);
DocIdSetIterator it = w.iterator;
int doc = w.doc;
if (doc < windowMin) {
doc = it.advance(windowMin);
}
collector.setScorer(w.scorer);
for (; doc < end; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
}
w.doc = doc;

// reset the scorer that should be used for the general case
collector.setScorer(score);
}

private BulkScorerAndDoc scoreWindow(
BulkScorerAndDoc top, LeafCollector collector, Bits acceptDocs, int min, int max)
private DisiWrapper scoreWindow(
DisiWrapper top, LeafCollector collector, Bits acceptDocs, int min, int max)
throws IOException {
final int windowBase = top.next & ~MASK; // find the window that the next match belongs to
final int windowBase = top.doc & ~MASK; // find the window that the next match belongs to
final int windowMin = Math.max(min, windowBase);
final int windowMax = Math.min(max, windowBase + SIZE);

// Fill 'leads' with all scorers from 'head' that are in the right window
leads[0] = head.pop();
int maxFreq = 1;
while (head.size() > 0 && head.top().next < windowMax) {
while (head.size() > 0 && head.top().doc < windowMax) {
leads[maxFreq++] = head.pop();
}

if (minShouldMatch == 1 && maxFreq == 1) {
// special case: only one scorer can match in the current window,
// we can collect directly
final BulkScorerAndDoc bulkScorer = leads[0];
final DisiWrapper bulkScorer = leads[0];
scoreWindowSingleScorer(bulkScorer, collector, acceptDocs, windowMin, windowMax, max);
return head.add(bulkScorer);
} else {
Expand All @@ -360,11 +343,11 @@ private BulkScorerAndDoc scoreWindow(
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
collector.setScorer(score);

BulkScorerAndDoc top = advance(min);
while (top.next < max) {
DisiWrapper top = advance(min);
while (top.doc < max) {
top = scoreWindow(top, collector, acceptDocs, min, max);
}

return top.next;
return top.doc;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ BulkScorer optionalBulkScorer() throws IOException {
return new MaxScoreBulkScorer(maxDoc, optionalScorers);
}

List<BulkScorer> optional = new ArrayList<BulkScorer>();
List<Scorer> optional = new ArrayList<Scorer>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optional.add(ss.bulkScorer());
optional.add(ss.get(Long.MAX_VALUE));
}

return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());
Expand Down
Loading