Skip to content

Commit

Permalink
Optimize in clause evaluation (#11557)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Sep 14, 2023
1 parent a8b5fe7 commit 8abe86b
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.BooleanUtils;
import org.apache.pinot.spi.utils.ByteArray;
import org.apache.pinot.spi.utils.CommonConstants;
import org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
import org.apache.pinot.spi.utils.TimestampUtils;


Expand Down Expand Up @@ -145,23 +145,35 @@ public static IntSet getDictIdSet(BaseInPredicate inPredicate, Dictionary dictio
}
break;
case STRING:
if (queryContext == null || values.size() <= Integer.parseInt(queryContext.getQueryOptions()
.getOrDefault(CommonConstants.Broker.Request.QueryOptionKey.IN_PREDICATE_SORT_THRESHOLD,
CommonConstants.Broker.Request.QueryOptionValue.DEFAULT_IN_PREDICATE_SORT_THRESHOLD))) {
for (String value : values) {
int dictId = dictionary.indexOf(value);
if (dictId >= 0) {
dictIdSet.add(dictId);
}
if (queryContext == null || values.size() <= 1) {
dictionary.getDictIds(values, dictIdSet);
break;
}
Dictionary.SortedBatchLookupAlgorithm lookupAlgorithm =
Dictionary.SortedBatchLookupAlgorithm.DIVIDE_BINARY_SEARCH;
String inPredicateLookupAlgorithm =
queryContext.getQueryOptions().get(QueryOptionKey.IN_PREDICATE_LOOKUP_ALGORITHM);
if (inPredicateLookupAlgorithm != null) {
try {
lookupAlgorithm = Dictionary.SortedBatchLookupAlgorithm.valueOf(inPredicateLookupAlgorithm.toUpperCase());
} catch (Exception e) {
throw new IllegalArgumentException("Illegal IN predicate lookup algorithm: " + inPredicateLookupAlgorithm);
}
}
if (lookupAlgorithm == Dictionary.SortedBatchLookupAlgorithm.PLAIN_BINARY_SEARCH) {
dictionary.getDictIds(values, dictIdSet);
break;
}
if (Boolean.parseBoolean(queryContext.getQueryOptions().get(QueryOptionKey.IN_PREDICATE_PRE_SORTED))) {
dictionary.getDictIds(values, dictIdSet, lookupAlgorithm);
} else {
List<String> sortedValues =
//noinspection unchecked
dictionary.getDictIds(
queryContext.getOrComputeSharedValue(List.class, Equivalence.identity().wrap(inPredicate), k -> {
List<String> copyValues = new ArrayList<>(values);
copyValues.sort(null);
return copyValues;
});
dictionary.getDictIds(sortedValues, dictIdSet);
List<String> sortedValues = new ArrayList<>(values);
sortedValues.sort(null);
return sortedValues;
}), dictIdSet, lookupAlgorithm);
}
break;
case BYTES:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.perf;

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.RandomStringUtils;
import org.apache.pinot.segment.local.segment.creator.impl.SegmentDictionaryCreator;
import org.apache.pinot.segment.local.segment.index.readers.StringDictionary;
import org.apache.pinot.segment.spi.V1Constants;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.segment.spi.memory.PinotDataBuffer;
import org.apache.pinot.spi.data.DimensionFieldSpec;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.ChainedOptionsBuilder;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.TimeValue;


@State(Scope.Benchmark)
public class BenchmarkDictionaryLookup {
  private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "BenchmarkDictionaryLookup");
  private static final int MAX_LENGTH = 10;
  private static final String COLUMN_NAME = "column";

  // Number of distinct entries stored in the dictionary under benchmark.
  @Param({"100", "1000", "10000", "100000", "1000000"})
  private int _cardinality;

  // Percentage of dictionary entries that are used as lookup values.
  @Param({"1", "2", "4", "8", "16", "32", "64", "100"})
  private int _lookupPercentage;

  private StringDictionary _dictionary;
  private List<String> _lookupValues;

  /**
   * Builds an on-disk string dictionary of {@code _cardinality} random ASCII values, memory-maps
   * it, and selects a sorted subset of the values (sized by {@code _lookupPercentage}) as the
   * lookup input for the benchmark methods.
   */
  @Setup
  public void setUp()
      throws IOException {
    FileUtils.deleteDirectory(INDEX_DIR);

    // Draw distinct random strings until the requested cardinality is reached.
    Set<String> distinctValues = new HashSet<>();
    while (distinctValues.size() < _cardinality) {
      distinctValues.add(RandomStringUtils.randomAscii(MAX_LENGTH));
    }
    String[] sortedValues = distinctValues.toArray(new String[0]);
    Arrays.sort(sortedValues);

    // Persist the sorted values as a segment dictionary and record the fixed per-entry width.
    int numBytesPerEntry;
    try (SegmentDictionaryCreator creator = new SegmentDictionaryCreator(
        new DimensionFieldSpec(COLUMN_NAME, DataType.STRING, true), INDEX_DIR, false)) {
      creator.build(sortedValues);
      numBytesPerEntry = creator.getNumBytesPerEntry();
    }
    _dictionary = new StringDictionary(
        PinotDataBuffer.mapReadOnlyBigEndianFile(new File(INDEX_DIR, COLUMN_NAME + V1Constants.Dict.FILE_EXTENSION)),
        _cardinality, numBytesPerEntry);

    // Choose the lookup values: either every entry, or a random sample kept in sorted order.
    int numLookupValues = _cardinality * _lookupPercentage / 100;
    if (numLookupValues == _cardinality) {
      _lookupValues = Arrays.asList(sortedValues);
      return;
    }
    IntSet sampledIds = new IntOpenHashSet();
    while (sampledIds.size() < numLookupValues) {
      sampledIds.add((int) (Math.random() * _cardinality));
    }
    int[] orderedIds = sampledIds.toIntArray();
    Arrays.sort(orderedIds);
    _lookupValues = new ArrayList<>(numLookupValues);
    for (int id : orderedIds) {
      _lookupValues.add(sortedValues[id]);
    }
  }

  @TearDown
  public void tearDown()
      throws Exception {
    FileUtils.deleteDirectory(INDEX_DIR);
  }

  @Benchmark
  @BenchmarkMode(Mode.Throughput)
  @OutputTimeUnit(TimeUnit.SECONDS)
  public int benchmarkDivideBinarySearch() {
    IntSet matchedIds = new IntOpenHashSet();
    _dictionary.getDictIds(_lookupValues, matchedIds, Dictionary.SortedBatchLookupAlgorithm.DIVIDE_BINARY_SEARCH);
    return matchedIds.size();
  }

  @Benchmark
  @BenchmarkMode(Mode.Throughput)
  @OutputTimeUnit(TimeUnit.SECONDS)
  public int benchmarkScan() {
    IntSet matchedIds = new IntOpenHashSet();
    _dictionary.getDictIds(_lookupValues, matchedIds, Dictionary.SortedBatchLookupAlgorithm.SCAN);
    return matchedIds.size();
  }

  @Benchmark
  @BenchmarkMode(Mode.Throughput)
  @OutputTimeUnit(TimeUnit.SECONDS)
  public int benchmarkPlainBinarySearch() {
    IntSet matchedIds = new IntOpenHashSet();
    _dictionary.getDictIds(_lookupValues, matchedIds);
    return matchedIds.size();
  }

  public static void main(String[] args)
      throws Exception {
    ChainedOptionsBuilder options = new OptionsBuilder()
        .include(BenchmarkDictionaryLookup.class.getSimpleName())
        .warmupIterations(1).warmupTime(TimeValue.seconds(3))
        .measurementIterations(1).measurementTime(TimeValue.seconds(5))
        .forks(1);
    new Runner(options.build()).run();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -282,34 +282,79 @@ protected byte[] getBuffer() {
return new byte[_numBytesPerValue];
}

/**
* Returns the dictionary id for the given sorted values.
* @param sortedValues
* @param dictIds
*/
@Override
public void getDictIds(List<String> sortedValues, IntSet dictIds) {
int valueIdx = 0;
int dictIdx = 0;
/**
 * Batch lookup of dict ids for a sorted list of values, dispatching to the requested sorted
 * batch lookup algorithm. Ids of matching values are added to {@code dictIds}; values absent
 * from the dictionary are skipped.
 *
 * @param sortedValues lookup values, assumed sorted ascending — TODO confirm callers guarantee this
 * @param dictIds output set collecting the dict ids of matched values
 * @param algorithm sorted batch lookup strategy to apply
 * @throws IllegalStateException if the algorithm is not supported by this implementation
 *         (e.g. {@code PLAIN_BINARY_SEARCH}, which callers are expected to route to the
 *         unsorted lookup instead)
 */
public void getDictIds(List<String> sortedValues, IntSet dictIds, SortedBatchLookupAlgorithm algorithm) {
  switch (algorithm) {
    case DIVIDE_BINARY_SEARCH:
      getDictIdsDivideBinarySearch(sortedValues, 0, sortedValues.size(), 0, _length, dictIds);
      break;
    case SCAN:
      getDictIdsScan(sortedValues, dictIds);
      break;
    default:
      throw new IllegalStateException("Unsupported sorted batch lookup algorithm: " + algorithm);
  }
}

/**
 * Recursive divide-and-conquer lookup: binary-search the middle value within the candidate dict
 * id range, then recurse on the left and right halves with correspondingly narrowed id ranges.
 * Because both the values and the dictionary are sorted, each half of the values can only match
 * its sub-range of dict ids, shrinking every nested binary search.
 */
private void getDictIdsDivideBinarySearch(List<String> sortedValues, int startIndex, int endIndex, int startDictId,
    int endDictId, IntSet dictIds) {
  // Stop when either the value slice or the candidate dict id range is empty.
  if (startIndex >= endIndex || startDictId >= endDictId) {
    return;
  }
  int midIndex = (startIndex + endIndex) >>> 1;  // unsigned shift: overflow-safe midpoint
  String midValue = sortedValues.get(midIndex);
  int dictId = binarySearch(midValue, startDictId, endDictId);
  if (dictId >= 0) {
    dictIds.add(dictId);
    // Values left of mid can only map below dictId; values right of mid only above it.
    getDictIdsDivideBinarySearch(sortedValues, startIndex, midIndex, startDictId, dictId, dictIds);
    getDictIdsDivideBinarySearch(sortedValues, midIndex + 1, endIndex, dictId + 1, endDictId, dictIds);
  } else {
    // Miss: -dictId - 1 is the insertion point, which bounds both halves of the id range.
    dictId = -dictId - 1;
    getDictIdsDivideBinarySearch(sortedValues, startIndex, midIndex, startDictId, dictId, dictIds);
    getDictIdsDivideBinarySearch(sortedValues, midIndex + 1, endIndex, dictId, endDictId, dictIds);
  }
}

/**
 * Binary-searches for {@code value} within the dict id range {@code [startDictId, endDictId)}.
 * Returns the matching dict id, or {@code -(insertionPoint + 1)} when the value is absent,
 * mirroring {@link java.util.Arrays#binarySearch} semantics.
 */
private int binarySearch(String value, int startDictId, int endDictId) {
  // Encode once up front; every probe compares against the same UTF-8 bytes.
  byte[] target = value.getBytes(UTF_8);
  int lo = startDictId;
  int hi = endDictId - 1;
  while (lo <= hi) {
    int mid = (lo + hi) >>> 1;  // overflow-safe midpoint
    int cmp = _valueReader.compareUtf8Bytes(mid, _numBytesPerValue, target);
    if (cmp == 0) {
      return mid;
    }
    if (cmp < 0) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return -(lo + 1);
}

private void getDictIdsScan(List<String> sortedValues, IntSet dictIds) {
int valueId = 0;
int dictId = 0;
byte[] utf8 = null;
boolean needNewUtf8 = true;
int sortedValuesSize = sortedValues.size();
int numSortedValues = sortedValues.size();
int dictLength = length();
while (valueIdx < sortedValuesSize && dictIdx < dictLength) {
while (valueId < numSortedValues && dictId < dictLength) {
if (needNewUtf8) {
utf8 = sortedValues.get(valueIdx).getBytes(StandardCharsets.UTF_8);
utf8 = sortedValues.get(valueId).getBytes(StandardCharsets.UTF_8);
}
int comparison = _valueReader.compareUtf8Bytes(dictIdx, _numBytesPerValue, utf8);
int comparison = _valueReader.compareUtf8Bytes(dictId, _numBytesPerValue, utf8);
if (comparison == 0) {
dictIds.add(dictIdx);
dictIdx++;
valueIdx++;
dictIds.add(dictId);
dictId++;
valueId++;
needNewUtf8 = true;
} else if (comparison > 0) {
valueIdx++;
valueId++;
needNewUtf8 = true;
} else {
dictIdx++;
dictId++;
needNewUtf8 = false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,26 @@ default void readBytesValues(int[] dictIds, int length, byte[][] outValues) {
}
}

/**
* Returns the dictIds for the given sorted values. This method is for the IN/NOT IN predicate evaluation.
* @param sortedValues
* @param dictIds
*/
default void getDictIds(List<String> sortedValues, IntSet dictIds) {
for (String value : sortedValues) {
/**
 * Adds the dict ids of the given values into {@code dictIds} via per-value {@code indexOf}
 * lookups. Values with no dictionary entry are silently skipped. The values need not be sorted.
 */
default void getDictIds(List<String> values, IntSet dictIds) {
  values.stream().mapToInt(this::indexOf).filter(dictId -> dictId >= 0).forEach(dictIds::add);
}

/**
 * Returns the dictIds for the given sorted values. This method is for the IN/NOT IN predicate
 * evaluation.
 *
 * <p>The default implementation ignores the {@code algorithm} hint and falls back to the
 * per-value unsorted lookup; sorted dictionary implementations override this to exploit the
 * ordering of the values.
 */
default void getDictIds(List<String> sortedValues, IntSet dictIds, SortedBatchLookupAlgorithm algorithm) {
  getDictIds(sortedValues, dictIds);
}

// Lookup strategies for resolving a batch of sorted values to dict ids.
enum SortedBatchLookupAlgorithm {
  DIVIDE_BINARY_SEARCH, SCAN,
  // Plain binary search gains nothing from the values being sorted, so callers should route it
  // to the unsorted-values lookup rather than the sorted batch API. It is kept here only so it
  // remains a valid query option value.
  PLAIN_BINARY_SEARCH
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ public static class QueryOptionKey {
public static final String GROUP_TRIM_THRESHOLD = "groupTrimThreshold";
public static final String STAGE_PARALLELISM = "stageParallelism";

// Handle IN predicate evaluation for big IN lists
public static final String IN_PREDICATE_SORT_THRESHOLD = "inPredicateSortThreshold";
public static final String IN_PREDICATE_PRE_SORTED = "inPredicatePreSorted";
public static final String IN_PREDICATE_LOOKUP_ALGORITHM = "inPredicateLookupAlgorithm";

public static final String DROP_RESULTS = "dropResults";

Expand All @@ -365,7 +365,6 @@ public static class QueryOptionKey {
}

public static class QueryOptionValue {
public static final String DEFAULT_IN_PREDICATE_SORT_THRESHOLD = "1000";
public static final int DEFAULT_MAX_STREAMING_PENDING_BLOCKS = 100;
}
}
Expand Down

0 comments on commit 8abe86b

Please sign in to comment.