From cf52567dadbaec2673904162f88077d4c2426632 Mon Sep 17 00:00:00 2001 From: Yash Mayya Date: Fri, 30 Aug 2024 20:35:35 +0530 Subject: [PATCH] Improve null handling performance for nullable single input aggregation functions (#13791) Modify aggregation functions that were not extending NullableSingleInputAggregationFunction to do so and optimize their code to use the methods included there. --- .../org/apache/pinot/core/data/table/Key.java | 5 + .../apache/pinot/core/data/table/Record.java | 5 + .../function/AvgAggregationFunction.java | 139 +--- .../function/AvgMVAggregationFunction.java | 4 +- .../BaseBooleanAggregationFunction.java | 45 +- ...eDistinctAggregateAggregationFunction.java | 350 +++------- .../function/CountAggregationFunction.java | 37 +- .../function/MaxAggregationFunction.java | 208 ++---- .../function/MinAggregationFunction.java | 207 ++---- ...ullableSingleInputAggregationFunction.java | 14 +- .../function/SumAggregationFunction.java | 199 ++---- .../SumPrecisionAggregationFunction.java | 319 +++------ .../function/VarianceAggregationFunction.java | 89 +-- .../array/ArrayAggDistinctDoubleFunction.java | 24 +- .../array/ArrayAggDistinctFloatFunction.java | 24 +- .../array/ArrayAggDistinctIntFunction.java | 23 +- .../array/ArrayAggDistinctLongFunction.java | 23 +- .../array/ArrayAggDistinctStringFunction.java | 21 +- .../array/ArrayAggDoubleFunction.java | 24 +- .../function/array/ArrayAggFloatFunction.java | 25 +- .../function/array/ArrayAggIntFunction.java | 24 +- .../function/array/ArrayAggLongFunction.java | 23 +- .../array/ArrayAggStringFunction.java | 21 +- .../array/BaseArrayAggDoubleFunction.java | 45 +- .../array/BaseArrayAggFloatFunction.java | 45 +- .../function/array/BaseArrayAggFunction.java | 72 +- .../array/BaseArrayAggIntFunction.java | 47 +- .../array/BaseArrayAggLongFunction.java | 43 +- .../array/BaseArrayAggStringFunction.java | 45 +- .../StatisticalAggregationFunctionUtils.java | 37 + .../AbstractAggregationFunctionTest.java | 
60 +- .../function/ArrayAggFunctionTest.java | 661 ++++++++++++++++++ .../function/AvgAggregationFunctionTest.java | 144 ++++ .../BooleanAggregationFunctionTest.java | 178 +++++ .../CountAggregationFunctionTest.java | 64 +- .../DistinctAggregationFunctionTest.java | 281 ++++++++ .../function/MaxAggregationFunctionTest.java | 148 ++++ .../function/MinAggregationFunctionTest.java | 150 ++++ .../MinMaxRangeAggregationFunctionTest.java | 56 +- .../function/SumAggregationFunctionTest.java | 147 ++++ .../SumPrecisionAggregationFunctionTest.java | 155 ++++ .../VarianceAggregationFunctionTest.java | 168 +++++ .../apache/pinot/queries/BaseQueriesTest.java | 4 + .../apache/pinot/queries/FluentQueryTest.java | 29 + .../pinot/perf/SyntheticBlockValSets.java | 2 +- .../AbstractAggregationFunctionBenchmark.java | 21 +- .../AbstractAggregationQueryBenchmark.java | 88 +++ .../aggregation/BenchmarkAvgAggregation.java | 121 ++++ .../BenchmarkDistinctCountAggregation.java | 107 +++ .../aggregation/BenchmarkMinAggregation.java | 123 ++++ .../BenchmarkModeAggregation.java | 4 +- .../aggregation/BenchmarkSumAggregation.java | 127 ++++ .../perf/aggregation/BenchmarkSumQuery.java | 121 ++++ .../BenchmarkVarianceAggregation.java | 117 ++++ .../segment/local/customobject/AvgPair.java | 4 + .../local/customobject/VarianceTuple.java | 9 + 56 files changed, 3727 insertions(+), 1549 deletions(-) create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/BooleanAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunctionTest.java create mode 100644 
pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunctionTest.java rename pinot-perf/src/main/java/org/apache/pinot/perf/{ => aggregation}/AbstractAggregationFunctionBenchmark.java (91%) create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationQueryBenchmark.java create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkAvgAggregation.java create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkDistinctCountAggregation.java create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkMinAggregation.java rename pinot-perf/src/main/java/org/apache/pinot/perf/{ => aggregation}/BenchmarkModeAggregation.java (97%) create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumAggregation.java create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumQuery.java create mode 100644 pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkVarianceAggregation.java diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/Key.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/Key.java index 8854b3a6e97..795123332e2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/Key.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/Key.java @@ -58,4 +58,9 @@ public boolean equals(Object o) 
{ public int hashCode() { return Arrays.hashCode(_values); } + + @Override + public String toString() { + return Arrays.toString(_values); + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/Record.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/Record.java index a2930ddb010..2f79fed0e75 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/Record.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/Record.java @@ -64,4 +64,9 @@ public boolean equals(Object o) { public int hashCode() { return Arrays.hashCode(_values); } + + @Override + public String toString() { + return Arrays.toString(_values); + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java index 9dbfb6cec57..aa74df1ac14 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java @@ -31,20 +31,17 @@ import org.apache.pinot.segment.local.customobject.AvgPair; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.data.FieldSpec.DataType; -import org.roaringbitmap.RoaringBitmap; -public class AvgAggregationFunction extends BaseSingleInputAggregationFunction { +public class AvgAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; - private final boolean _nullHandlingEnabled; public AvgAggregationFunction(List arguments, boolean nullHandlingEnabled) { this(verifySingleArgument(arguments, "AVG"), nullHandlingEnabled); } protected AvgAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + 
super(expression, nullHandlingEnabled); } @Override @@ -66,73 +63,37 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap); - return; - } - } if (blockValSet.getValueType() != DataType.BYTES) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - double sum = 0.0; - for (int i = 0; i < length; i++) { - sum += doubleValues[i]; - } - setAggregationResult(aggregationResultHolder, sum, length); - } else { - // Serialized AvgPair - byte[][] bytesValues = blockValSet.getBytesValuesSV(); - double sum = 0.0; - long count = 0L; - for (int i = 0; i < length; i++) { - AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); - sum += value.getSum(); - count += value.getCount(); - } - setAggregationResult(aggregationResultHolder, sum, count); - } - } - - private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - if (blockValSet.getValueType() != DataType.BYTES) { - double[] doubleValues = blockValSet.getDoubleValuesSV(); - if (nullBitmap.getCardinality() < length) { - double sum = 0.0; - long count = 0L; - // TODO: need to update the for-loop terminating condition to: i < length & i < doubleValues.length? 
- for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum += doubleValues[i]; - count++; - } + AvgPair avgPair = new AvgPair(); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + avgPair.apply(doubleValues[i], 1); } - setAggregationResult(aggregationResultHolder, sum, count); + }); + // Only set the aggregation result when there is at least one non-null input value + if (avgPair.getCount() != 0) { + updateAggregationResult(aggregationResultHolder, avgPair.getSum(), avgPair.getCount()); } - // Note: when all input values re null (nullBitmap.getCardinality() == values.length), avg is null. As a result, - // we don't call setAggregationResult. } else { // Serialized AvgPair byte[][] bytesValues = blockValSet.getBytesValuesSV(); - if (nullBitmap.getCardinality() < length) { - double sum = 0.0; - long count = 0L; - // TODO: need to update the for-loop terminating condition to: i < length & i < bytesValues.length? - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); - sum += value.getSum(); - count += value.getCount(); - } + AvgPair avgPair = new AvgPair(); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); + avgPair.apply(value); } - setAggregationResult(aggregationResultHolder, sum, count); + }); + // Only set the aggregation result when there is at least one non-null input value + if (avgPair.getCount() != 0) { + updateAggregationResult(aggregationResultHolder, avgPair.getSum(), avgPair.getCount()); } } } - protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, double sum, long count) { + protected void updateAggregationResult(AggregationResultHolder aggregationResultHolder, double sum, long count) { AvgPair avgPair = aggregationResultHolder.getResult(); if (avgPair == 
null) { aggregationResultHolder.setValue(new AvgPair(sum, count)); @@ -145,55 +106,23 @@ protected void setAggregationResult(AggregationResultHolder aggregationResultHol public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray, groupByResultHolder, blockValSet, nullBitmap); - return; - } - } - - if (blockValSet.getValueType() != DataType.BYTES) { - double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, doubleValues[i], 1L); - } - } else { - // Serialized AvgPair - byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); - setGroupByResult(groupKeyArray[i], groupByResultHolder, avgPair.getSum(), avgPair.getCount()); - } - } - } - private void aggregateGroupBySVNullHandlingEnabled(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { if (blockValSet.getValueType() != DataType.BYTES) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - // TODO: need to update the for-loop terminating condition to: i < length & i < valueArray.length? 
- if (nullBitmap.getCardinality() < length) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - int groupKey = groupKeyArray[i]; - setGroupByResult(groupKey, groupByResultHolder, doubleValues[i], 1L); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + updateGroupByResult(groupKeyArray[i], groupByResultHolder, doubleValues[i], 1L); } - } + }); } else { // Serialized AvgPair byte[][] bytesValues = blockValSet.getBytesValuesSV(); - // TODO: need to update the for-loop terminating condition to: i < length & i < valueArray.length? - if (nullBitmap.getCardinality() < length) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - int groupKey = groupKeyArray[i]; - AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); - setGroupByResult(groupKey, groupByResultHolder, avgPair.getSum(), avgPair.getCount()); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]); + updateGroupByResult(groupKeyArray[i], groupByResultHolder, avgPair.getSum(), avgPair.getCount()); } - } + }); } } @@ -207,7 +136,7 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult for (int i = 0; i < length; i++) { double value = doubleValues[i]; for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupKey, groupByResultHolder, value, 1L); + updateGroupByResult(groupKey, groupByResultHolder, value, 1L); } } } else { @@ -218,13 +147,13 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult double sum = avgPair.getSum(); long count = avgPair.getCount(); for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupKey, groupByResultHolder, sum, count); + updateGroupByResult(groupKey, groupByResultHolder, sum, count); } } } } - protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, 
double sum, long count) { + protected void updateGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sum, long count) { AvgPair avgPair = groupByResultHolder.getResult(groupKey); if (avgPair == null) { groupByResultHolder.setValueForKey(groupKey, new AvgPair(sum, count)); @@ -237,7 +166,7 @@ protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultH public AvgPair extractAggregationResult(AggregationResultHolder aggregationResultHolder) { AvgPair avgPair = aggregationResultHolder.getResult(); if (avgPair == null) { - return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L); + return _nullHandlingEnabled ? null : new AvgPair(); } return avgPair; } @@ -246,7 +175,7 @@ public AvgPair extractAggregationResult(AggregationResultHolder aggregationResul public AvgPair extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { AvgPair avgPair = groupByResultHolder.getResult(groupKey); if (avgPair == null) { - return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L); + return _nullHandlingEnabled ? 
null : new AvgPair(); } return avgPair; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java index 50be99a05af..d0bc36fabbe 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java @@ -51,7 +51,7 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde } count += values.length; } - setAggregationResult(aggregationResultHolder, sum, count); + updateAggregationResult(aggregationResultHolder, sum, count); } @Override @@ -81,6 +81,6 @@ private void aggregateOnGroupKey(int groupKey, GroupByResultHolder groupByResult sum += value; } long count = values.length; - setGroupByResult(groupKey, groupByResultHolder, sum, count); + updateGroupByResult(groupKey, groupByResultHolder, sum, count); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java index c6b1216ca99..5c14c44dcc5 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseBooleanAggregationFunction.java @@ -30,16 +30,14 @@ import org.apache.pinot.core.query.aggregation.groupby.IntGroupByResultHolder; import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; // TODO: change this to implement BaseSingleInputAggregationFunction when we get proper // handling of booleans in serialization - today this would fail because 
ColumnDataType#convert assumes // that the boolean is encoded as its stored type (an integer) -public abstract class BaseBooleanAggregationFunction extends BaseSingleInputAggregationFunction { +public abstract class BaseBooleanAggregationFunction extends NullableSingleInputAggregationFunction { private final BooleanMerge _merger; - private final boolean _nullHandlingEnabled; protected enum BooleanMerge { AND { @@ -84,8 +82,7 @@ int getDefaultValue() { protected BaseBooleanAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled, BooleanMerge merger) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + super(expression, nullHandlingEnabled); _merger = merger; } @@ -115,27 +112,22 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde int[] bools = blockValSet.getIntValuesSV(); if (_nullHandlingEnabled) { - int agg = getInt(aggregationResultHolder.getResult()); - // early terminate on a per-block level to allow the // loop below to be more tightly optimized (avoid a branch) - if (_merger.isTerminal(agg)) { + if (_merger.isTerminal(getInt(aggregationResultHolder.getResult()))) { return; } - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } else if (nullBitmap.getCardinality() > length) { - return; - } - - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - agg = _merger.merge(agg, bools[i]); - aggregationResultHolder.setValue((Object) agg); + Integer aggResult = foldNotNull(length, blockValSet, aggregationResultHolder.getResult(), + (acum, from, to) -> { + int innerBool = acum == null ? 
_merger.getDefaultValue() : acum; + for (int i = from; i < to; i++) { + innerBool = _merger.merge(innerBool, bools[i]); } - } + return innerBool; + }); + + aggregationResultHolder.setValue(aggResult); } else { int agg = aggregationResultHolder.getIntResult(); @@ -164,20 +156,13 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol int[] bools = blockValSet.getIntValuesSV(); if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } else if (nullBitmap.getCardinality() > length) { - return; - } - - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { int groupByKey = groupKeyArray[i]; int agg = getInt(groupByResultHolder.getResult(groupByKey)); groupByResultHolder.setValueForKey(groupByKey, (Object) _merger.merge(agg, bools[i])); } - } + }); } else { for (int i = 0; i < length; i++) { int groupByKey = groupKeyArray[i]; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseDistinctAggregateAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseDistinctAggregateAggregationFunction.java index 49754b11f09..c29363b8854 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseDistinctAggregateAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseDistinctAggregateAggregationFunction.java @@ -47,15 +47,13 @@ */ @SuppressWarnings({"rawtypes", "unchecked"}) public abstract class BaseDistinctAggregateAggregationFunction - extends BaseSingleInputAggregationFunction { + extends NullableSingleInputAggregationFunction { private final AggregationFunctionType _functionType; - private final boolean _nullHandlingEnabled; protected 
BaseDistinctAggregateAggregationFunction(ExpressionContext expression, AggregationFunctionType aggregationFunctionType, boolean nullHandlingEnabled) { - super(expression); + super(expression, nullHandlingEnabled); _functionType = aggregationFunctionType; - _nullHandlingEnabled = nullHandlingEnabled; } @Override @@ -77,7 +75,7 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public Set extractAggregationResult(AggregationResultHolder aggregationResultHolder) { Object result = aggregationResultHolder.getResult(); if (result == null) { - // Use empty IntOpenHashSet as a place holder for empty result + // Use empty IntOpenHashSet as a placeholder for empty result return new IntOpenHashSet(); } @@ -146,25 +144,13 @@ protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregati protected void svAggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - RoaringBitmap nullBitmap = null; // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } int[] dictIds = blockValSet.getDictionaryIdsSV(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - RoaringBitmap dictIdBitmap = getDictIdBitmap(aggregationResultHolder, dictionary); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - dictIdBitmap.add(dictIds[i]); - } - } - } else { - getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length); - } + RoaringBitmap dictIdBitmap = getDictIdBitmap(aggregationResultHolder, dictionary); + forEachNotNull(length, blockValSet, (from, to) -> dictIdBitmap.addN(dictIds, from, to - from)); return; } @@ -175,111 +161,63 @@ protected void svAggregate(int length, AggregationResultHolder aggregationResult case INT: IntOpenHashSet intSet = 
(IntOpenHashSet) valueSet; int[] intValues = blockValSet.getIntValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - intSet.add(intValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { intSet.add(intValues[i]); } - } + }); break; case LONG: LongOpenHashSet longSet = (LongOpenHashSet) valueSet; long[] longValues = blockValSet.getLongValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - longSet.add(longValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { longSet.add(longValues[i]); } - } + }); break; case FLOAT: FloatOpenHashSet floatSet = (FloatOpenHashSet) valueSet; float[] floatValues = blockValSet.getFloatValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - floatSet.add(floatValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { floatSet.add(floatValues[i]); } - } + }); break; case DOUBLE: DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) valueSet; double[] doubleValues = blockValSet.getDoubleValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - doubleSet.add(doubleValues[i]); - } - } - } else { - for (int i = 0; i < 
length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { doubleSet.add(doubleValues[i]); } - } + }); break; case STRING: ObjectOpenHashSet stringSet = (ObjectOpenHashSet) valueSet; String[] stringValues = blockValSet.getStringValuesSV(); //noinspection ManualArrayToCollectionCopy - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - stringSet.add(stringValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { stringSet.add(stringValues[i]); } - } + }); break; case BYTES: ObjectOpenHashSet bytesSet = (ObjectOpenHashSet) valueSet; byte[][] bytesValues = blockValSet.getBytesValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - bytesSet.add(new ByteArray(bytesValues[i])); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { bytesSet.add(new ByteArray(bytesValues[i])); } - } + }); break; default: throw new IllegalStateException( @@ -368,26 +306,17 @@ protected void mvAggregate(int length, AggregationResultHolder aggregationResult protected void svAggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - RoaringBitmap nullBitmap = null; // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } int[] dictIds = 
blockValSet.getDictionaryIdsSV(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]); } - } + }); return; } @@ -396,112 +325,60 @@ protected void svAggregateGroupBySV(int length, int[] groupKeyArray, GroupByResu switch (storedType) { case INT: int[] intValues = blockValSet.getIntValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((IntOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.INT)).add(intValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { ((IntOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.INT)).add(intValues[i]); } - } + }); break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((LongOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.LONG)).add(longValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { ((LongOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.LONG)).add(longValues[i]); } - } + }); break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); - if (_nullHandlingEnabled) { - 
nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((FloatOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.FLOAT)).add( - floatValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { ((FloatOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.FLOAT)).add(floatValues[i]); } - } + }); break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.DOUBLE)).add( - doubleValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { - ((DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.DOUBLE)).add( - doubleValues[i]); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + ((DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.DOUBLE)) + .add(doubleValues[i]); } - } + }); break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.STRING)).add( - stringValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { - ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.STRING)).add( - stringValues[i]); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) 
{ + ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.STRING)) + .add(stringValues[i]); } - } + }); break; case BYTES: byte[][] bytesValues = blockValSet.getBytesValuesSV(); - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.BYTES)).add( - new ByteArray(bytesValues[i])); - } - } - } else { - for (int i = 0; i < length; i++) { - ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.BYTES)).add( - new ByteArray(bytesValues[i])); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + ((ObjectOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.BYTES)) + .add(new ByteArray(bytesValues[i])); } - } + }); break; default: throw new IllegalStateException( @@ -591,26 +468,17 @@ protected void mvAggregateGroupBySV(int length, int[] groupKeyArray, GroupByResu protected void svAggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - RoaringBitmap nullBitmap = null; // For dictionary-encoded expression, store dictionary ids into the bitmap Dictionary dictionary = blockValSet.getDictionary(); if (dictionary != null) { - if (_nullHandlingEnabled) { - nullBitmap = blockValSet.getNullBitmap(); - } int[] dictIds = blockValSet.getDictionaryIdsSV(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { 
setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]); } - } + }); return; } @@ -619,93 +487,57 @@ protected void svAggregateGroupByMV(int length, int[][] groupKeysArray, GroupByR switch (storedType) { case INT: int[] intValues = blockValSet.getIntValuesSV(); - nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], intValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], intValues[i]); } - } + }); break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); - nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], longValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], longValues[i]); } - } + }); break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); - nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], floatValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], floatValues[i]); } - } + }); break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); - nullBitmap = 
blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], doubleValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], doubleValues[i]); } - } + }); break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); - nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i]); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i]); } - } + }); break; case BYTES: byte[][] bytesValues = blockValSet.getBytesValuesSV(); - nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], new ByteArray(bytesValues[i])); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], new ByteArray(bytesValues[i])); } - } + }); break; default: throw new IllegalStateException( diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java index b222803a442..982e673ca6d 100644 --- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java @@ -33,27 +33,26 @@ import org.roaringbitmap.RoaringBitmap; -public class CountAggregationFunction extends BaseSingleInputAggregationFunction { +public class CountAggregationFunction extends NullableSingleInputAggregationFunction { private static final String COUNT_STAR_RESULT_COLUMN_NAME = "count(*)"; private static final double DEFAULT_INITIAL_VALUE = 0.0; // Special expression used by star-tree to pass in BlockValSet private static final ExpressionContext STAR_TREE_COUNT_STAR_EXPRESSION = ExpressionContext.forIdentifier(AggregationFunctionColumnPair.STAR); - private final boolean _nullHandlingEnabled; public CountAggregationFunction(List arguments, boolean nullHandlingEnabled) { - this(verifySingleArgument(arguments, "COUNT"), nullHandlingEnabled); - } - - protected CountAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { - super(expression); // Consider null values only when null handling is enabled and function is not COUNT(*) // Note COUNT on any literal gives same result as COUNT(*) // So allow for identifiers that are not * and functions, disable for literals and * - _nullHandlingEnabled = nullHandlingEnabled && ( - (expression.getType() == ExpressionContext.Type.IDENTIFIER && !expression.getIdentifier().equals("*")) || ( - expression.getType() == ExpressionContext.Type.FUNCTION)); + this(verifySingleArgument(arguments, "COUNT"), nullHandlingEnabled + && ((arguments.get(0).getType() == ExpressionContext.Type.IDENTIFIER + && !arguments.get(0).getIdentifier().equals("*")) + || (arguments.get(0).getType() == ExpressionContext.Type.FUNCTION))); + } + + protected CountAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); } @Override @@ -126,23 +125,13 @@ 
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol // 0 | 1 assert blockValSetMap.size() == 1; BlockValSet blockValSet = blockValSetMap.values().iterator().next(); - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - if (nullBitmap.getCardinality() == length) { - return; - } - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - int groupKey = groupKeyArray[i]; - groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + 1); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { int groupKey = groupKeyArray[i]; groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + 1); } - } + }); } else { // Star-tree pre-aggregated values long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java index c2d37d35d9a..03f74bdd0b1 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java @@ -31,20 +31,17 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.roaringbitmap.RoaringBitmap; -public class MaxAggregationFunction extends BaseSingleInputAggregationFunction { +public class MaxAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_INITIAL_VALUE = Double.NEGATIVE_INFINITY; - 
private final boolean _nullHandlingEnabled; public MaxAggregationFunction(List arguments, boolean nullHandlingEnabled) { this(verifySingleArgument(arguments, "MAX"), nullHandlingEnabled); } protected MaxAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + super(expression, nullHandlingEnabled); } @Override @@ -72,61 +69,77 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - // TODO: avoid the null bitmap check when it is null or empty for better performance. - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap); - return; - } switch (blockValSet.getValueType().getStoredType()) { case INT: { int[] values = blockValSet.getIntValuesSV(); - int max = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - max = Math.max(values[i], max); - } - aggregationResultHolder.setValue(Math.max(max, aggregationResultHolder.getDoubleResult())); + + Integer max = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + int innerMax = values[from]; + for (int i = from; i < to; i++) { + innerMax = Math.max(innerMax, values[i]); + } + return acum == null ? 
innerMax : Math.max(acum, innerMax); + }); + + updateAggregationResultHolder(aggregationResultHolder, max); break; } case LONG: { long[] values = blockValSet.getLongValuesSV(); - long max = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - max = Math.max(values[i], max); - } - aggregationResultHolder.setValue(Math.max(max, aggregationResultHolder.getDoubleResult())); + + Long max = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + long innerMax = values[from]; + for (int i = from; i < to; i++) { + innerMax = Math.max(innerMax, values[i]); + } + return acum == null ? innerMax : Math.max(acum, innerMax); + }); + + updateAggregationResultHolder(aggregationResultHolder, max); break; } case FLOAT: { float[] values = blockValSet.getFloatValuesSV(); - float max = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - max = Math.max(values[i], max); - } - aggregationResultHolder.setValue(Math.max(max, aggregationResultHolder.getDoubleResult())); + + Float max = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + float innerMax = values[from]; + for (int i = from; i < to; i++) { + innerMax = Math.max(innerMax, values[i]); + } + return acum == null ? innerMax : Math.max(acum, innerMax); + }); + + updateAggregationResultHolder(aggregationResultHolder, max); break; } case DOUBLE: { double[] values = blockValSet.getDoubleValuesSV(); - double max = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - max = Math.max(values[i], max); - } - aggregationResultHolder.setValue(Math.max(max, aggregationResultHolder.getDoubleResult())); + + Double max = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerMax = values[from]; + for (int i = from; i < to; i++) { + innerMax = Math.max(innerMax, values[i]); + } + return acum == null ? 
innerMax : Math.max(acum, innerMax); + }); + + updateAggregationResultHolder(aggregationResultHolder, max); break; } case BIG_DECIMAL: { BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - BigDecimal max = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - max = values[i].max(max); - } + + BigDecimal max = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerMax = values[from]; + for (int i = from; i < to; i++) { + innerMax = innerMax.max(values[i]); + } + return acum == null ? innerMax : acum.max(innerMax); + }); + // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. - aggregationResultHolder.setValue(Math.max(max.doubleValue(), aggregationResultHolder.getDoubleResult())); + updateAggregationResultHolder(aggregationResultHolder, max); break; } default: @@ -134,117 +147,42 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde } } - private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - switch (blockValSet.getValueType().getStoredType()) { - case INT: { - if (nullBitmap.getCardinality() < length) { - int[] values = blockValSet.getIntValuesSV(); - int max = Integer.MIN_VALUE; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - max = Math.max(values[i], max); - } - } - updateAggregationResultHolder(aggregationResultHolder, max); - } - // Note: when all input values re null (nullBitmap.getCardinality() == values.length), max is null. As a result, - // we don't update the value of aggregationResultHolder. 
- break; - } - case LONG: { - if (nullBitmap.getCardinality() < length) { - long[] values = blockValSet.getLongValuesSV(); - long max = Long.MIN_VALUE; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - max = Math.max(values[i], max); - } - } - updateAggregationResultHolder(aggregationResultHolder, max); - } - break; - } - case FLOAT: { - if (nullBitmap.getCardinality() < length) { - float[] values = blockValSet.getFloatValuesSV(); - float max = Float.NEGATIVE_INFINITY; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - max = Math.max(values[i], max); - } - } - updateAggregationResultHolder(aggregationResultHolder, max); - } - break; - } - case DOUBLE: { - if (nullBitmap.getCardinality() < length) { - double[] values = blockValSet.getDoubleValuesSV(); - double max = Double.NEGATIVE_INFINITY; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - max = Math.max(values[i], max); - } - } - updateAggregationResultHolder(aggregationResultHolder, max); - } - break; - } - case BIG_DECIMAL: { - if (nullBitmap.getCardinality() < length) { - BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - BigDecimal max = null; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - max = max == null ? values[i] : values[i].max(max); - } - } - // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. - assert max != null; - updateAggregationResultHolder(aggregationResultHolder, max.doubleValue()); - } - break; + private void updateAggregationResultHolder(AggregationResultHolder aggregationResultHolder, Number max) { + if (max != null) { + if (_nullHandlingEnabled) { + Double otherMax = aggregationResultHolder.getResult(); + aggregationResultHolder.setValue(otherMax == null ? 
max.doubleValue() : Math.max(max.doubleValue(), otherMax)); + } else { + double otherMax = aggregationResultHolder.getDoubleResult(); + aggregationResultHolder.setValue(Math.max(max.doubleValue(), otherMax)); } - default: - throw new IllegalStateException("Cannot compute max for non-numeric type: " + blockValSet.getValueType()); } } - private void updateAggregationResultHolder(AggregationResultHolder aggregationResultHolder, double max) { - Double otherMax = aggregationResultHolder.getResult(); - aggregationResultHolder.setValue(otherMax == null ? max : Math.max(max, otherMax)); - } - @Override public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - if (nullBitmap.getCardinality() < length) { - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { double value = valueArray[i]; int groupKey = groupKeyArray[i]; Double result = groupByResultHolder.getResult(groupKey); - if (!nullBitmap.contains(i) && (result == null || value > result)) { + if (result == null || value > result) { groupByResultHolder.setValueForKey(groupKey, value); } } - } - return; - } - - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - double value = valueArray[i]; - int groupKey = groupKeyArray[i]; - if (value > groupByResultHolder.getDoubleResult(groupKey)) { - groupByResultHolder.setValueForKey(groupKey, value); + }); + } else { + for (int i = 0; i < length; i++) { + double value = valueArray[i]; + int groupKey = groupKeyArray[i]; + if (value > 
groupByResultHolder.getDoubleResult(groupKey)) { + groupByResultHolder.setValueForKey(groupKey, value); + } } } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java index a74b7a53ee0..b8484b18f47 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java @@ -31,20 +31,17 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.roaringbitmap.RoaringBitmap; -public class MinAggregationFunction extends BaseSingleInputAggregationFunction { +public class MinAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_VALUE = Double.POSITIVE_INFINITY; - private final boolean _nullHandlingEnabled; public MinAggregationFunction(List arguments, boolean nullHandlingEnabled) { this(verifySingleArgument(arguments, "MIN"), nullHandlingEnabled); } protected MinAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + super(expression, nullHandlingEnabled); } @Override @@ -72,60 +69,77 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - aggregateNullHandlingEnabled(length, aggregationResultHolder, 
blockValSet, nullBitmap); - return; - } switch (blockValSet.getValueType().getStoredType()) { case INT: { int[] values = blockValSet.getIntValuesSV(); - int min = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - min = Math.min(values[i], min); - } - aggregationResultHolder.setValue(Math.min(min, aggregationResultHolder.getDoubleResult())); + + Integer min = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + int innerMin = values[from]; + for (int i = from; i < to; i++) { + innerMin = Math.min(innerMin, values[i]); + } + return acum == null ? innerMin : Math.min(acum, innerMin); + }); + + updateAggregationResultHolder(aggregationResultHolder, min); break; } case LONG: { long[] values = blockValSet.getLongValuesSV(); - long min = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - min = Math.min(values[i], min); - } - aggregationResultHolder.setValue(Math.min(min, aggregationResultHolder.getDoubleResult())); + + Long min = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + long innerMin = values[from]; + for (int i = from; i < to; i++) { + innerMin = Math.min(innerMin, values[i]); + } + return acum == null ? innerMin : Math.min(acum, innerMin); + }); + + updateAggregationResultHolder(aggregationResultHolder, min); break; } case FLOAT: { float[] values = blockValSet.getFloatValuesSV(); - float min = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - min = Math.min(values[i], min); - } - aggregationResultHolder.setValue(Math.min(min, aggregationResultHolder.getDoubleResult())); + + Float min = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + float innerMin = values[from]; + for (int i = from; i < to; i++) { + innerMin = Math.min(innerMin, values[i]); + } + return acum == null ? 
innerMin : Math.min(acum, innerMin); + }); + + updateAggregationResultHolder(aggregationResultHolder, min); break; } case DOUBLE: { double[] values = blockValSet.getDoubleValuesSV(); - double min = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - min = Math.min(values[i], min); - } - aggregationResultHolder.setValue(Math.min(min, aggregationResultHolder.getDoubleResult())); + + Double min = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerMin = values[from]; + for (int i = from; i < to; i++) { + innerMin = Math.min(innerMin, values[i]); + } + return acum == null ? innerMin : Math.min(acum, innerMin); + }); + + updateAggregationResultHolder(aggregationResultHolder, min); break; } case BIG_DECIMAL: { BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - BigDecimal min = values[0]; - for (int i = 0; i < length & i < values.length; i++) { - min = values[i].min(min); - } + + BigDecimal min = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerMin = values[from]; + for (int i = from; i < to; i++) { + innerMin = innerMin.min(values[i]); + } + return acum == null ? innerMin : acum.min(innerMin); + }); + // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. 
- aggregationResultHolder.setValue(Math.min(min.doubleValue(), aggregationResultHolder.getDoubleResult())); + updateAggregationResultHolder(aggregationResultHolder, min); break; } default: @@ -133,117 +147,42 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde } } - private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - switch (blockValSet.getValueType().getStoredType()) { - case INT: { - if (nullBitmap.getCardinality() < length) { - int[] values = blockValSet.getIntValuesSV(); - int min = Integer.MAX_VALUE; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - min = Math.min(values[i], min); - } - } - updateAggregationResultHolder(aggregationResultHolder, min); - } - // Note: when all input values re null (nullBitmap.getCardinality() == values.length), min is null. As a result, - // we don't update the value of aggregationResultHolder. 
- break; - } - case LONG: { - if (nullBitmap.getCardinality() < length) { - long[] values = blockValSet.getLongValuesSV(); - long min = Long.MAX_VALUE; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - min = Math.min(values[i], min); - } - } - updateAggregationResultHolder(aggregationResultHolder, min); - } - break; - } - case FLOAT: { - if (nullBitmap.getCardinality() < length) { - float[] values = blockValSet.getFloatValuesSV(); - float min = Float.POSITIVE_INFINITY; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - min = Math.min(values[i], min); - } - } - updateAggregationResultHolder(aggregationResultHolder, min); - } - break; - } - case DOUBLE: { - if (nullBitmap.getCardinality() < length) { - double[] values = blockValSet.getDoubleValuesSV(); - double min = Double.POSITIVE_INFINITY; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - min = Math.min(values[i], min); - } - } - updateAggregationResultHolder(aggregationResultHolder, min); - } - break; - } - case BIG_DECIMAL: { - if (nullBitmap.getCardinality() < length) { - BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - BigDecimal min = null; - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - min = min == null ? values[i] : values[i].min(min); - } - } - assert min != null; - // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. - updateAggregationResultHolder(aggregationResultHolder, min.doubleValue()); - } - break; + private void updateAggregationResultHolder(AggregationResultHolder aggregationResultHolder, Number min) { + if (min != null) { + if (_nullHandlingEnabled) { + Double otherMin = aggregationResultHolder.getResult(); + aggregationResultHolder.setValue(otherMin == null ? 
min.doubleValue() : Math.min(min.doubleValue(), otherMin)); + } else { + double otherMin = aggregationResultHolder.getDoubleResult(); + aggregationResultHolder.setValue(Math.min(min.doubleValue(), otherMin)); } - default: - throw new IllegalStateException("Cannot compute min for non-numeric type: " + blockValSet.getValueType()); } } - private void updateAggregationResultHolder(AggregationResultHolder aggregationResultHolder, double min) { - Double otherMin = aggregationResultHolder.getResult(); - aggregationResultHolder.setValue(otherMin == null ? min : Math.min(min, otherMin)); - } - @Override public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - if (nullBitmap.getCardinality() < length) { - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { double value = valueArray[i]; int groupKey = groupKeyArray[i]; Double result = groupByResultHolder.getResult(groupKey); - if (!nullBitmap.contains(i) && (result == null || value < result)) { + if (result == null || value < result) { groupByResultHolder.setValueForKey(groupKey, value); } } - } - return; - } - - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - double value = valueArray[i]; - int groupKey = groupKeyArray[i]; - if (value < groupByResultHolder.getDoubleResult(groupKey)) { - groupByResultHolder.setValueForKey(groupKey, value); + }); + } else { + for (int i = 0; i < length; i++) { + double value = valueArray[i]; + int groupKey = groupKeyArray[i]; + if (value < 
groupByResultHolder.getDoubleResult(groupKey)) { + groupByResultHolder.setValueForKey(groupKey, value); + } } } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java index af2a41610e0..a9e06725d10 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java @@ -107,16 +107,19 @@ public void forEachNotNull(int length, IntIterator nullIndexIterator, BatchConsu } /** - * Folds over the non-null ranges of the blockValSet using the reducer. + * Folds over the non-null ranges of the blockValSet using the reducer. Returns {@code initialAcum} if the entire + * block is null. + * * @param initialAcum the initial value of the accumulator * @param The type of the accumulator */ public A foldNotNull(int length, BlockValSet blockValSet, A initialAcum, Reducer reducer) { - return foldNotNull(length, blockValSet.getNullBitmap(), initialAcum, reducer); + return foldNotNull(length, _nullHandlingEnabled ? blockValSet.getNullBitmap() : null, initialAcum, reducer); } /** - * Folds over the non-null ranges of the blockValSet using the reducer. + * Folds over the non-null ranges of the blockValSet using the reducer. Returns {@code initialAcum} if the entire + * block is null. 
* @param initialAcum the initial value of the accumulator * @param The type of the accumulator */ @@ -139,6 +142,11 @@ public A foldNotNull(int length, @Nullable RoaringBitmap roaringBitmap, A in */ public A foldNotNull(int length, @Nullable IntIterator nullIndexIterator, A initialAcum, Reducer reducer) { A acum = initialAcum; + + if (length == 0) { + return acum; + } + if (!_nullHandlingEnabled || nullIndexIterator == null || !nullIndexIterator.hasNext()) { return reducer.apply(initialAcum, 0, length); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java index b90dcc20510..2c44c783aba 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java @@ -31,20 +31,17 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.roaringbitmap.RoaringBitmap; -public class SumAggregationFunction extends BaseSingleInputAggregationFunction { +public class SumAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_VALUE = 0.0; - private final boolean _nullHandlingEnabled; public SumAggregationFunction(List arguments, boolean nullHandlingEnabled) { this(verifySingleArgument(arguments, "SUM"), nullHandlingEnabled); } protected SumAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + super(expression, nullHandlingEnabled); } @Override @@ -72,170 +69,112 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public 
void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap); - return; - } - double sum = aggregationResultHolder.getDoubleResult(); + Double sum; switch (blockValSet.getValueType().getStoredType()) { case INT: { int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - sum += values[i]; - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerSum = 0; + for (int i = from; i < to; i++) { + innerSum += values[i]; + } + return acum == null ? innerSum : acum + innerSum; + }); + break; } case LONG: { long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - sum += values[i]; - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerSum = 0; + for (int i = from; i < to; i++) { + innerSum += values[i]; + } + return acum == null ? innerSum : acum + innerSum; + }); + break; } case FLOAT: { float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - sum += values[i]; - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerSum = 0; + for (int i = from; i < to; i++) { + innerSum += values[i]; + } + return acum == null ? innerSum : acum + innerSum; + }); + break; } case DOUBLE: { double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - sum += values[i]; - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + double innerSum = 0; + for (int i = from; i < to; i++) { + innerSum += values[i]; + } + return acum == null ? 
innerSum : acum + innerSum; + }); + break; } case BIG_DECIMAL: { - BigDecimal decimalSum = BigDecimal.valueOf(sum); BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - decimalSum = decimalSum.add(values[i]); - } + + BigDecimal decimalSum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(values[i]); + } + return acum == null ? innerSum : acum.add(innerSum); + }); // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. - sum = decimalSum.doubleValue(); + sum = decimalSum == null ? null : decimalSum.doubleValue(); break; } default: throw new IllegalStateException("Cannot compute sum for non-numeric type: " + blockValSet.getValueType()); } - aggregationResultHolder.setValue(sum); + updateAggregationResultHolder(aggregationResultHolder, sum); } - private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - double sum = 0; - switch (blockValSet.getValueType().getStoredType()) { - case INT: { - if (nullBitmap.getCardinality() < length) { - int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - sum += values[i]; - } - } - setAggregationResultHolder(aggregationResultHolder, sum); - } - break; + private void updateAggregationResultHolder(AggregationResultHolder aggregationResultHolder, Double sum) { + if (sum != null) { + if (_nullHandlingEnabled) { + Double otherSum = aggregationResultHolder.getResult(); + aggregationResultHolder.setValue(otherSum == null ? 
sum : sum + otherSum); + } else { + double otherSum = aggregationResultHolder.getDoubleResult(); + aggregationResultHolder.setValue(sum + otherSum); } - case LONG: { - if (nullBitmap.getCardinality() < length) { - long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - sum += values[i]; - } - } - setAggregationResultHolder(aggregationResultHolder, sum); - } - break; - } - case FLOAT: { - if (nullBitmap.getCardinality() < length) { - float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - sum += values[i]; - } - } - setAggregationResultHolder(aggregationResultHolder, sum); - } - break; - } - case DOUBLE: { - if (nullBitmap.getCardinality() < length) { - double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - sum += values[i]; - } - } - setAggregationResultHolder(aggregationResultHolder, sum); - } - break; - } - case BIG_DECIMAL: { - if (nullBitmap.getCardinality() < length) { - BigDecimal[] values = blockValSet.getBigDecimalValuesSV(); - BigDecimal decimalSum = BigDecimal.valueOf(sum); - for (int i = 0; i < length & i < values.length; i++) { - if (!nullBitmap.contains(i)) { - decimalSum = decimalSum.add(values[i]); - } - } - // TODO: even though the source data has BIG_DECIMAL type, we still only support double precision. - setAggregationResultHolder(aggregationResultHolder, decimalSum.doubleValue()); - } - break; - } - default: - throw new IllegalStateException("Cannot compute sum for non-numeric type: " + blockValSet.getValueType()); } } - private void setAggregationResultHolder(AggregationResultHolder aggregationResultHolder, double sum) { - Double otherSum = aggregationResultHolder.getResult(); - aggregationResultHolder.setValue(otherSum == null ? 
sum : sum + otherSum); - } - @Override public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap == null) { - nullBitmap = new RoaringBitmap(); - } - if (nullBitmap.getCardinality() < length) { - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - int groupKey = groupKeyArray[i]; - Double result = groupByResultHolder.getResult(groupKey); - groupByResultHolder.setValueForKey(groupKey, result == null ? valueArray[i] : result + valueArray[i]); - // In presto: - // SELECT sum (cast(id AS DOUBLE)) as sum, min(id) as min, max(id) as max, key FROM (VALUES (null, 1), - // (null, 2)) AS t(id, key) GROUP BY key ORDER BY max DESC; - // sum | min | max | key - //------+------+------+----- - // NULL | NULL | NULL | 2 - // NULL | NULL | NULL | 1 - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + int groupKey = groupKeyArray[i]; + Double result = groupByResultHolder.getResult(groupKey); + groupByResultHolder.setValueForKey(groupKey, result == null ? 
valueArray[i] : result + valueArray[i]); } + }); + } else { + for (int i = 0; i < length; i++) { + int groupKey = groupKeyArray[i]; + groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]); } - return; - } - - double[] valueArray = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - int groupKey = groupKeyArray[i]; - groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java index 2bad736974c..c1b07881257 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunction.java @@ -33,7 +33,6 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.utils.BigDecimalUtils; -import org.roaringbitmap.RoaringBitmap; /** @@ -46,13 +45,12 @@ *
  • Scale (optional): scale to be set to the final result
  • * */ -public class SumPrecisionAggregationFunction extends BaseSingleInputAggregationFunction { +public class SumPrecisionAggregationFunction extends NullableSingleInputAggregationFunction { private final Integer _precision; private final Integer _scale; - private final boolean _nullHandlingEnabled; public SumPrecisionAggregationFunction(List arguments, boolean nullHandlingEnabled) { - super(arguments.get(0)); + super(arguments.get(0), nullHandlingEnabled); int numArguments = arguments.size(); Preconditions.checkArgument(numArguments <= 3, "SumPrecision expects at most 3 arguments, got: %s", numArguments); @@ -67,7 +65,6 @@ public SumPrecisionAggregationFunction(List arguments, boolea _precision = null; _scale = null; } - _nullHandlingEnabled = nullHandlingEnabled; } @Override @@ -89,281 +86,156 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap); - return; - } - } - BigDecimal sum = getDefaultResult(aggregationResultHolder); + BigDecimal sum; switch (blockValSet.getValueType().getStoredType()) { case INT: int[] intValues = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - sum = sum.add(BigDecimal.valueOf(intValues[i])); - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(BigDecimal.valueOf(intValues[i])); + } + return acum == null ? 
innerSum : acum.add(innerSum); + }); + break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - sum = sum.add(BigDecimal.valueOf(longValues[i])); - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(BigDecimal.valueOf(longValues[i])); + } + return acum == null ? innerSum : acum.add(innerSum); + }); + break; case FLOAT: case DOUBLE: case STRING: String[] stringValues = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - sum = sum.add(new BigDecimal(stringValues[i])); - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(new BigDecimal(stringValues[i])); + } + return acum == null ? innerSum : acum.add(innerSum); + }); + break; case BIG_DECIMAL: BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV(); - for (int i = 0; i < length; i++) { - sum = sum.add(bigDecimalValues[i]); - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(bigDecimalValues[i]); + } + return acum == null ? innerSum : acum.add(innerSum); + }); + break; case BYTES: byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i])); - } + + sum = foldNotNull(length, blockValSet, null, (acum, from, to) -> { + BigDecimal innerSum = BigDecimal.ZERO; + for (int i = from; i < to; i++) { + innerSum = innerSum.add(BigDecimalUtils.deserialize(bytesValues[i])); + } + return acum == null ? 
innerSum : acum.add(innerSum); + }); + break; default: throw new IllegalStateException(); } - aggregationResultHolder.setValue(sum); + updateAggregationResult(aggregationResultHolder, sum); } - private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - BigDecimal sum = BigDecimal.ZERO; - switch (blockValSet.getValueType().getStoredType()) { - case INT: { - if (nullBitmap.getCardinality() < length) { - int[] intValues = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum = sum.add(BigDecimal.valueOf(intValues[i])); - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; - } - case LONG: { - if (nullBitmap.getCardinality() < length) { - long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum = sum.add(BigDecimal.valueOf(longValues[i])); - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; - } - case FLOAT: { - if (nullBitmap.getCardinality() < length) { - float[] floatValues = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - if (Float.isFinite(floatValues[i])) { - sum = sum.add(BigDecimal.valueOf(floatValues[i])); - } - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; - } - case DOUBLE: { - if (nullBitmap.getCardinality() < length) { - double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - // TODO(nhejazi): throw an exception here instead of ignoring infinite values? 
- if (Double.isFinite(doubleValues[i])) { - sum = sum.add(BigDecimal.valueOf(doubleValues[i])); - } - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; + protected void updateAggregationResult(AggregationResultHolder aggregationResultHolder, BigDecimal sum) { + if (_nullHandlingEnabled) { + if (sum != null) { + BigDecimal otherSum = aggregationResultHolder.getResult(); + aggregationResultHolder.setValue(otherSum == null ? sum : sum.add(otherSum)); } - case STRING: - if (nullBitmap.getCardinality() < length) { - String[] stringValues = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum = sum.add(new BigDecimal(stringValues[i])); - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; - case BIG_DECIMAL: { - if (nullBitmap.getCardinality() < length) { - BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum = sum.add(bigDecimalValues[i]); - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; + } else { + if (sum == null) { + sum = BigDecimal.ZERO; } - case BYTES: - if (nullBitmap.getCardinality() < length) { - byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i])); - } - } - setAggregationResult(aggregationResultHolder, sum); - } - break; - default: - throw new IllegalStateException(); + BigDecimal otherSum = aggregationResultHolder.getResult(); + aggregationResultHolder.setValue(otherSum == null ? sum : sum.add(otherSum)); } } - protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, BigDecimal sum) { - BigDecimal otherSum = aggregationResultHolder.getResult(); - aggregationResultHolder.setValue(otherSum == null ? 
sum : sum.add(otherSum)); - } - @Override public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray, groupByResultHolder, blockValSet, nullBitmap); - return; - } - } switch (blockValSet.getValueType().getStoredType()) { case INT: int[] intValues = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - int groupKey = groupKeyArray[i]; - BigDecimal sum = getDefaultResult(groupByResultHolder, groupKey); - sum = sum.add(BigDecimal.valueOf(intValues[i])); - groupByResultHolder.setValueForKey(groupKey, sum); - } - break; - case LONG: - long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - int groupKey = groupKeyArray[i]; - BigDecimal sum = getDefaultResult(groupByResultHolder, groupKey); - sum = sum.add(BigDecimal.valueOf(longValues[i])); - groupByResultHolder.setValueForKey(groupKey, sum); - } - break; - case FLOAT: - case DOUBLE: - case STRING: - String[] stringValues = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - int groupKey = groupKeyArray[i]; - BigDecimal sum = getDefaultResult(groupByResultHolder, groupKey); - sum = sum.add(new BigDecimal(stringValues[i])); - groupByResultHolder.setValueForKey(groupKey, sum); - } - break; - case BIG_DECIMAL: - BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV(); - for (int i = 0; i < length; i++) { - int groupKey = groupKeyArray[i]; - BigDecimal sum = getDefaultResult(groupByResultHolder, groupKey); - sum = sum.add(bigDecimalValues[i]); - groupByResultHolder.setValueForKey(groupKey, sum); - } - break; - case BYTES: - byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - int 
groupKey = groupKeyArray[i]; - BigDecimal sum = getDefaultResult(groupByResultHolder, groupKey); - sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i])); - groupByResultHolder.setValueForKey(groupKey, sum); - } - break; - default: - throw new IllegalStateException(); - } - } - private void aggregateGroupBySVNullHandlingEnabled(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - switch (blockValSet.getValueType().getStoredType()) { - case INT: - if (nullBitmap.getCardinality() < length) { - int[] intValues = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(intValues[i])); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + updateGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(intValues[i])); } - } + }); + break; case LONG: - if (nullBitmap.getCardinality() < length) { - long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(longValues[i])); - } + long[] longValues = blockValSet.getLongValuesSV(); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + updateGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(longValues[i])); } - } + }); + break; case FLOAT: case DOUBLE: case STRING: - if (nullBitmap.getCardinality() < length) { - String[] stringValues = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, new BigDecimal(stringValues[i])); - } + String[] stringValues = blockValSet.getStringValuesSV(); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) 
{ + updateGroupByResult(groupKeyArray[i], groupByResultHolder, new BigDecimal(stringValues[i])); } - } + }); + break; case BIG_DECIMAL: - if (nullBitmap.getCardinality() < length) { - BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, bigDecimalValues[i]); - } + BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV(); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + updateGroupByResult(groupKeyArray[i], groupByResultHolder, bigDecimalValues[i]); } - } + }); + break; case BYTES: - if (nullBitmap.getCardinality() < length) { - byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimalUtils.deserialize(bytesValues[i])); - } + byte[][] bytesValues = blockValSet.getBytesValuesSV(); + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + updateGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimalUtils.deserialize(bytesValues[i])); } - } + }); + break; default: throw new IllegalStateException(); } } - private void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, BigDecimal value) { + private void updateGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, BigDecimal value) { BigDecimal sum = groupByResultHolder.getResult(groupKey); sum = sum == null ? value : sum.add(value); groupByResultHolder.setValueForKey(groupKey, sum); @@ -494,11 +366,6 @@ public BigDecimal mergeFinalResult(BigDecimal finalResult1, BigDecimal finalResu return merge(finalResult1, finalResult2); } - public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) { - BigDecimal result = aggregationResultHolder.getResult(); - return result != null ? 
result : BigDecimal.ZERO; - } - public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) { BigDecimal result = groupByResultHolder.getResult(groupKey); return result != null ? result : BigDecimal.ZERO; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java index 6c49dfaa7a0..141035d40ed 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java @@ -30,28 +30,25 @@ import org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils; import org.apache.pinot.segment.local.customobject.VarianceTuple; import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.roaringbitmap.RoaringBitmap; /** * Aggregation function which computes Variance and Standard Deviation * * The algorithm to compute variance is based on "Updating Formulae and a Pairwise Algorithm for Computing - * Sample Variances" by Chan et al. Please refer to the "Parallel Algorithm" section from the following wiki: - * - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + * Sample Variances" by Chan et al. Please refer to the "Parallel Algorithm" section from + * -
    this wiki */ -public class VarianceAggregationFunction extends BaseSingleInputAggregationFunction { +public class VarianceAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; protected final boolean _isSample; protected final boolean _isStdDev; - protected final boolean _nullHandlingEnabled; public VarianceAggregationFunction(List arguments, boolean isSample, boolean isStdDev, boolean nullHandlingEnabled) { - super(verifySingleArgument(arguments, getFunctionName(isSample, isStdDev))); + super(verifySingleArgument(arguments, getFunctionName(isSample, isStdDev)), nullHandlingEnabled); _isSample = isSample; _isStdDev = isStdDev; - _nullHandlingEnabled = nullHandlingEnabled; } private static String getFunctionName(boolean isSample, boolean isStdDev) { @@ -80,44 +77,20 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); - RoaringBitmap nullBitmap = null; - if (_nullHandlingEnabled) { - nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); - } - long count = 0; - double sum = 0.0; - double variance = 0.0; - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - count++; - sum += values[i]; - if (count > 1) { - variance = computeIntermediateVariance(count, sum, variance, values[i]); - } - } - } - } else { - for (int i = 0; i < length; i++) { - count++; - sum += values[i]; - if (count > 1) { - variance = computeIntermediateVariance(count, sum, variance, values[i]); - } + VarianceTuple varianceTuple = new VarianceTuple(0L, 0.0, 0.0); + + forEachNotNull(length, blockValSetMap.get(_expression), (from, to) -> { + for (int i = from; i < to; i++) { + varianceTuple.apply(values[i]); 
} - } + }); - if (_nullHandlingEnabled && count == 0) { + if (_nullHandlingEnabled && varianceTuple.getCount() == 0L) { return; } - setAggregationResult(aggregationResultHolder, count, sum, variance); - } - - private double computeIntermediateVariance(long count, double sum, double m2, double value) { - double t = count * value - sum; - m2 += (t * t) / (count * (count - 1)); - return m2; + setAggregationResult(aggregationResultHolder, varianceTuple.getCount(), varianceTuple.getSum(), + varianceTuple.getM2()); } protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, long count, double sum, @@ -144,46 +117,26 @@ protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultH public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); - RoaringBitmap nullBitmap = null; - if (_nullHandlingEnabled) { - nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 0.0); - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSetMap.get(_expression), (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 0.0); } - } + }); } @Override public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); - RoaringBitmap nullBitmap = null; - if (_nullHandlingEnabled) { - nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); - } - if (nullBitmap != null && !nullBitmap.isEmpty()) { - for (int i = 0; i < length; i++) 
{ - if (!nullBitmap.contains(i)) { - for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0); - } - } - } - } else { - for (int i = 0; i < length; i++) { + + forEachNotNull(length, blockValSetMap.get(_expression), (from, to) -> { + for (int i = from; i < to; i++) { for (int groupKey : groupKeysArray[i]) { setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java index f5dacc3f4b2..ffcb36b7a4c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java @@ -19,11 +19,11 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggDistinctDoubleFunction extends BaseArrayAggDoubleFunction { @@ -32,25 +32,17 @@ public ArrayAggDistinctDoubleFunction(ExpressionContext expression, boolean null } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { - DoubleOpenHashSet valueArray = new DoubleOpenHashSet(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = 
blockValSetMap.get(_expression); double[] value = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { DoubleOpenHashSet valueArray = new DoubleOpenHashSet(length); - double[] value = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java index cf3e21eeb12..e1c6b335b57 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java @@ -19,11 +19,11 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggDistinctFloatFunction extends BaseArrayAggFloatFunction { @@ -32,25 +32,17 @@ public ArrayAggDistinctFloatFunction(ExpressionContext expression, boolean nullH } @Override - protected void aggregateArray(int length, AggregationResultHolder 
aggregationResultHolder, BlockValSet blockValSet) { - FloatOpenHashSet valueArray = new FloatOpenHashSet(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); float[] value = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { FloatOpenHashSet valueArray = new FloatOpenHashSet(length); - float[] value = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java index 1ccbb27b0ad..edb8c2e646f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java @@ -19,12 +19,12 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import 
org.roaringbitmap.RoaringBitmap; public class ArrayAggDistinctIntFunction extends BaseArrayAggIntFunction { @@ -34,26 +34,17 @@ public ArrayAggDistinctIntFunction(ExpressionContext expression, FieldSpec.DataT } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet) { + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); int[] value = blockValSet.getIntValuesSV(); IntOpenHashSet valueArray = new IntOpenHashSet(length); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - int[] value = blockValSet.getIntValuesSV(); - IntOpenHashSet valueArray = new IntOpenHashSet(length); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java index bc35b9cb360..31c145557bb 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java @@ -19,12 +19,12 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; 
import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggDistinctLongFunction extends BaseArrayAggLongFunction { @@ -34,26 +34,17 @@ public ArrayAggDistinctLongFunction(ExpressionContext expression, FieldSpec.Data } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet) { + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); long[] value = blockValSet.getLongValuesSV(); LongOpenHashSet valueArray = new LongOpenHashSet(length); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - long[] value = blockValSet.getLongValuesSV(); - LongOpenHashSet valueArray = new LongOpenHashSet(length); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java index 0a964e3f10b..6b6c43370ba 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java +++ 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.util.Arrays; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggDistinctStringFunction extends BaseArrayAggStringFunction> { @@ -33,23 +33,12 @@ public ArrayAggDistinctStringFunction(ExpressionContext expression, boolean null } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { - ObjectOpenHashSet valueArray = new ObjectOpenHashSet<>(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); String[] value = blockValSet.getStringValuesSV(); - valueArray.addAll(Arrays.asList(value)); - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { ObjectOpenHashSet valueArray = new ObjectOpenHashSet<>(length); - String[] value = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - valueArray.add(value[i]); - } - } + forEachNotNull(length, blockValSet, (from, to) -> valueArray.addAll(Arrays.asList(value).subList(from, to))); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java index 78a29697bee..05c95f41616 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java @@ -19,11 +19,11 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggDoubleFunction extends BaseArrayAggDoubleFunction { @@ -32,25 +32,17 @@ public ArrayAggDoubleFunction(ExpressionContext expression, boolean nullHandling } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { - DoubleArrayList valueArray = new DoubleArrayList(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); double[] value = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { DoubleArrayList valueArray = new DoubleArrayList(length); - double[] value = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { 
valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java index 50cb800801a..6aa893737f3 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java @@ -19,11 +19,11 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.floats.FloatArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggFloatFunction extends BaseArrayAggFloatFunction { @@ -32,30 +32,21 @@ public ArrayAggFloatFunction(ExpressionContext expression, boolean nullHandlingE } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { - FloatArrayList valueArray = new FloatArrayList(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); float[] value = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { FloatArrayList valueArray = new FloatArrayList(length); - float[] value = 
blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } @Override - protected void setGroupByResult(GroupByResultHolder resultHolder, int groupKey, float value) { FloatArrayList valueArray = resultHolder.getResult(groupKey); if (valueArray == null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java index 45589ab739b..d2194b1af55 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java @@ -19,12 +19,12 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.ints.IntArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggIntFunction extends BaseArrayAggIntFunction { @@ -33,29 +33,21 @@ public ArrayAggIntFunction(ExpressionContext expression, FieldSpec.DataType data } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet) { + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); int[] value = blockValSet.getIntValuesSV(); IntArrayList 
valueArray = new IntArrayList(length); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - int[] value = blockValSet.getIntValuesSV(); - IntArrayList valueArray = new IntArrayList(length); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } + @Override protected void setGroupByResult(GroupByResultHolder resultHolder, int groupKey, int value) { IntArrayList valueArray = resultHolder.getResult(groupKey); if (valueArray == null) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java index 61643a5fda0..d41ddc0dcbf 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java @@ -19,12 +19,12 @@ package org.apache.pinot.core.query.aggregation.function.array; import it.unimi.dsi.fastutil.longs.LongArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public class ArrayAggLongFunction extends BaseArrayAggLongFunction { @@ -33,26 +33,17 @@ public ArrayAggLongFunction(ExpressionContext 
expression, FieldSpec.DataType dat } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet) { + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); long[] value = blockValSet.getLongValuesSV(); LongArrayList valueArray = new LongArrayList(length); - for (int i = 0; i < length; i++) { - valueArray.add(value[i]); - } - aggregationResultHolder.setValue(valueArray); - } - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { - long[] value = blockValSet.getLongValuesSV(); - LongArrayList valueArray = new LongArrayList(length); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { valueArray.add(value[i]); } - } + }); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java index 5abfe51dd29..5064dee12a4 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.objects.ObjectArrayList; import java.util.Arrays; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; -import 
org.roaringbitmap.RoaringBitmap; public class ArrayAggStringFunction extends BaseArrayAggStringFunction> { @@ -33,23 +33,12 @@ public ArrayAggStringFunction(ExpressionContext expression, boolean nullHandling } @Override - protected void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { - ObjectArrayList valueArray = new ObjectArrayList<>(length); + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); String[] value = blockValSet.getStringValuesSV(); - valueArray.addAll(Arrays.asList(value)); - aggregationResultHolder.setValue(valueArray); - } - - @Override - protected void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap) { ObjectArrayList valueArray = new ObjectArrayList<>(length); - String[] value = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - valueArray.add(value[i]); - } - } + forEachNotNull(length, blockValSet, (from, to) -> valueArray.addAll(Arrays.asList(value).subList(from, to))); aggregationResultHolder.setValue(valueArray); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java index de3e52b80e5..c4d789c8261 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.doubles.AbstractDoubleCollection; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import java.util.Map; import 
org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public abstract class BaseArrayAggDoubleFunction @@ -36,47 +36,32 @@ public BaseArrayAggDoubleFunction(ExpressionContext expression, boolean nullHand abstract void setGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey, double value); @Override - protected void aggregateArrayGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, - BlockValSet blockValSet) { + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); - } - } - @Override - protected void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); } - } + }); } @Override - protected void aggregateArrayGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, - BlockValSet blockValSet) { + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { 
- setGroupByResult(groupByResultHolder, groupKey, values[i]); - } - } - } - @Override - protected void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - double[] values = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - for (int groupKey : groupKeysArray[i]) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + int[] groupKeys = groupKeysArray[i]; + for (int groupKey : groupKeys) { setGroupByResult(groupByResultHolder, groupKey, values[i]); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java index a6f4078dc54..776e87bac9d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.floats.AbstractFloatCollection; import it.unimi.dsi.fastutil.floats.FloatArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public abstract class BaseArrayAggFloatFunction @@ -36,47 +36,32 @@ public BaseArrayAggFloatFunction(ExpressionContext expression, boolean nullHandl abstract void setGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey, float value); @Override - protected void aggregateArrayGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, - 
BlockValSet blockValSet) { + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); - } - } - @Override - protected void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); } - } + }); } @Override - protected void aggregateArrayGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, - BlockValSet blockValSet) { + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupByResultHolder, groupKey, values[i]); - } - } - } - @Override - protected void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - float[] values = blockValSet.getFloatValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + int[] groupKeys = groupKeysArray[i]; + for (int groupKey : groupKeys) { setGroupByResult(groupByResultHolder, groupKey, 
values[i]); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFunction.java index 5f17c16a197..9199abdca8b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFunction.java @@ -18,28 +18,24 @@ */ package org.apache.pinot.core.query.aggregation.function.array; -import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.utils.DataSchema; -import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder; -import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction; +import org.apache.pinot.core.query.aggregation.function.NullableSingleInputAggregationFunction; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; import org.apache.pinot.segment.spi.AggregationFunctionType; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; -public abstract class BaseArrayAggFunction extends BaseSingleInputAggregationFunction { +public abstract class BaseArrayAggFunction + extends NullableSingleInputAggregationFunction { - protected final boolean _nullHandlingEnabled; private final DataSchema.ColumnDataType _resultColumnType; public BaseArrayAggFunction(ExpressionContext expression, FieldSpec.DataType dataType, boolean nullHandlingEnabled) { - super(expression); - _nullHandlingEnabled = nullHandlingEnabled; + super(expression, nullHandlingEnabled); _resultColumnType = 
DataSchema.ColumnDataType.fromDataTypeMV(dataType); } @@ -68,66 +64,6 @@ public DataSchema.ColumnDataType getFinalResultColumnType() { return _resultColumnType; } - @Override - public void aggregate(int length, AggregationResultHolder aggregationResultHolder, - Map blockValSetMap) { - BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateArrayWithNull(length, aggregationResultHolder, blockValSet, nullBitmap); - return; - } - } - aggregateArray(length, aggregationResultHolder, blockValSet); - } - - protected abstract void aggregateArray(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet); - - protected abstract void aggregateArrayWithNull(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet, RoaringBitmap nullBitmap); - - @Override - public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, - Map blockValSetMap) { - BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateArrayGroupBySVWithNull(length, groupKeyArray, groupByResultHolder, blockValSet, nullBitmap); - return; - } - } - aggregateArrayGroupBySV(length, groupKeyArray, groupByResultHolder, blockValSet); - } - - protected abstract void aggregateArrayGroupBySV(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet); - - protected abstract void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap); - - @Override - public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, - Map blockValSetMap) 
{ - BlockValSet blockValSet = blockValSetMap.get(_expression); - if (_nullHandlingEnabled) { - RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); - if (nullBitmap != null && !nullBitmap.isEmpty()) { - aggregateArrayGroupByMVWithNull(length, groupKeysArray, groupByResultHolder, blockValSet, nullBitmap); - return; - } - } - aggregateArrayGroupByMV(length, groupKeysArray, groupByResultHolder, blockValSet); - } - - protected abstract void aggregateArrayGroupByMV(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet); - - protected abstract void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap); - @Override public I extractAggregationResult(AggregationResultHolder aggregationResultHolder) { return aggregationResultHolder.getResult(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java index 7f05d158578..24449ea97fd 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.ints.AbstractIntCollection; import it.unimi.dsi.fastutil.ints.IntArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public abstract class BaseArrayAggIntFunction @@ -37,51 +37,32 @@ public BaseArrayAggIntFunction(ExpressionContext expression, FieldSpec.DataType abstract void 
setGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey, int value); @Override - protected void aggregateArrayGroupBySV(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet) { + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); - } - } - @Override - protected void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); } - } + }); } @Override - protected void aggregateArrayGroupByMV(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet) { + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - int[] groupKeys = groupKeysArray[i]; - int value = values[i]; - for (int groupKey : groupKeys) { - setGroupByResult(groupByResultHolder, groupKey, value); - } - } - } - @Override - protected void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - int[] values = blockValSet.getIntValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + 
forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { int[] groupKeys = groupKeysArray[i]; - int value = values[i]; for (int groupKey : groupKeys) { - setGroupByResult(groupByResultHolder, groupKey, value); + setGroupByResult(groupByResultHolder, groupKey, values[i]); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java index 76e705945af..7348216bc69 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java @@ -20,11 +20,11 @@ import it.unimi.dsi.fastutil.longs.AbstractLongCollection; import it.unimi.dsi.fastutil.longs.LongArrayList; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public abstract class BaseArrayAggLongFunction @@ -37,49 +37,32 @@ public BaseArrayAggLongFunction(ExpressionContext expression, FieldSpec.DataType abstract void setGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey, long value); @Override - protected void aggregateArrayGroupBySV(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet) { + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupByResultHolder, groupKeyArray[i], 
values[i]); - } - } - @Override - protected void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); } - } + }); } @Override - protected void aggregateArrayGroupByMV(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet) { + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - int[] groupKeys = groupKeysArray[i]; - for (int groupKey : groupKeys) { - setGroupByResult(groupByResultHolder, groupKey, values[i]); - } - } - } - @Override - protected void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - long[] values = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { int[] groupKeys = groupKeysArray[i]; for (int groupKey : groupKeys) { setGroupByResult(groupByResultHolder, groupKey, values[i]); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java index 2e80eb6040a..e1f1aedacf4 100644 --- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java @@ -21,11 +21,11 @@ import it.unimi.dsi.fastutil.objects.AbstractObjectCollection; import it.unimi.dsi.fastutil.objects.ObjectArrayList; import it.unimi.dsi.fastutil.objects.ObjectIterators; +import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; import org.apache.pinot.spi.data.FieldSpec; -import org.roaringbitmap.RoaringBitmap; public abstract class BaseArrayAggStringFunction> @@ -37,47 +37,32 @@ public BaseArrayAggStringFunction(ExpressionContext expression, boolean nullHand abstract void setGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey, String value); @Override - protected void aggregateArrayGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, - BlockValSet blockValSet) { + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); String[] values = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); - } - } - @Override - protected void aggregateArrayGroupBySVWithNull(int length, int[] groupKeyArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - String[] values = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]); } - } + }); } @Override - protected void 
aggregateArrayGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, - BlockValSet blockValSet) { + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); String[] values = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupByResultHolder, groupKey, values[i]); - } - } - } - @Override - protected void aggregateArrayGroupByMVWithNull(int length, int[][] groupKeysArray, - GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) { - String[] values = blockValSet.getStringValuesSV(); - for (int i = 0; i < length; i++) { - if (!nullBitmap.contains(i)) { - for (int groupKey : groupKeysArray[i]) { + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + int[] groupKeys = groupKeysArray[i]; + for (int groupKey : groupKeys) { setGroupByResult(groupByResultHolder, groupKey, values[i]); } } - } + }); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/StatisticalAggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/StatisticalAggregationFunctionUtils.java index b7a05de3a41..988d8348279 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/StatisticalAggregationFunctionUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/StatisticalAggregationFunctionUtils.java @@ -19,9 +19,11 @@ package org.apache.pinot.core.query.aggregation.utils; import com.google.common.base.Preconditions; +import java.util.List; import java.util.Map; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.segment.spi.AggregationFunctionType; /** 
@@ -50,4 +52,39 @@ public static double[] getValSet(Map blockValSet + blockValSet.getValueType()); } } + + public static Double calculateVariance(List values, AggregationFunctionType aggregationFunctionType) { + long count = 0; + double sum = 0; + double variance = 0; + + for (Double value : values) { + count++; + sum += value; + if (count > 1) { + variance = computeIntermediateVariance(count, sum, variance, value); + } + } + + assert count > 1; + + switch (aggregationFunctionType) { + case VARPOP: + return variance / count; + case VARSAMP: + return variance / (count - 1); + case STDDEVPOP: + return Math.sqrt(variance / count); + case STDDEVSAMP: + return Math.sqrt(variance / (count - 1)); + default: + throw new IllegalArgumentException(); + } + } + + public static double computeIntermediateVariance(long count, double sum, double m2, double value) { + double t = count * value - sum; + m2 += (t * t) / (count * (count - 1)); + return m2; + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractAggregationFunctionTest.java index 122374d224c..0e90fe98936 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractAggregationFunctionTest.java @@ -57,13 +57,31 @@ public abstract class AbstractAggregationFunctionTest { FieldSpec.DataType.BOOLEAN }; - protected static final Map SINGLE_FIELD_NULLABLE_SCHEMAS = Arrays.stream(VALID_DATA_TYPES) + private static final FieldSpec.DataType[] VALID_METRIC_DATA_TYPES = new FieldSpec.DataType[] { + FieldSpec.DataType.INT, + FieldSpec.DataType.LONG, + FieldSpec.DataType.FLOAT, + FieldSpec.DataType.DOUBLE, + FieldSpec.DataType.BIG_DECIMAL, + FieldSpec.DataType.BYTES + }; + + protected static final Map 
SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS = + Arrays.stream(VALID_DATA_TYPES) .collect(Collectors.toMap(dt -> dt, dt -> new Schema.SchemaBuilder() .setSchemaName("testTable") .setEnableColumnBasedNullHandling(true) .addDimensionField("myField", dt, f -> f.setNullable(true)) .build())); + protected static final Map SINGLE_FIELD_NULLABLE_METRIC_SCHEMAS = + Arrays.stream(VALID_METRIC_DATA_TYPES) + .collect(Collectors.toMap(dt -> dt, dt -> new Schema.SchemaBuilder() + .setSchemaName("testTable") + .setEnableColumnBasedNullHandling(true) + .addMetricField("myField", dt, f -> f.setNullable(true)) + .build())); + protected static final TableConfig SINGLE_FIELD_TABLE_CONFIG = new TableConfigBuilder(TableType.OFFLINE) .setTableName("testTable") .build(); @@ -75,6 +93,15 @@ protected FluentQueryTest.DeclaringTable givenSingleNullableFieldTable(FieldSpec protected FluentQueryTest.DeclaringTable givenSingleNullableFieldTable(FieldSpec.DataType dataType, boolean nullHandlingEnabled, @Nullable Consumer customize) { + return givenSingleNullableFieldTable(dataType, nullHandlingEnabled, FieldSpec.FieldType.DIMENSION, customize); + } + + protected FluentQueryTest.DeclaringTable givenSingleNullableFieldTable(FieldSpec.DataType dataType, + boolean nullHandlingEnabled, FieldSpec.FieldType fieldType, @Nullable Consumer customize) { + if (fieldType != FieldSpec.FieldType.DIMENSION && fieldType != FieldSpec.FieldType.METRIC) { + throw new IllegalArgumentException("Only METRIC and DIMENSION field types are supported"); + } + TableConfig tableConfig; if (customize == null) { tableConfig = SINGLE_FIELD_TABLE_CONFIG; @@ -88,9 +115,12 @@ protected FluentQueryTest.DeclaringTable givenSingleNullableFieldTable(FieldSpec tableConfig = builder.build(); } + Schema schema = fieldType == FieldSpec.FieldType.DIMENSION + ? 
SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(dataType) + : SINGLE_FIELD_NULLABLE_METRIC_SCHEMAS.get(dataType); return FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(nullHandlingEnabled) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(dataType), tableConfig); + .givenTable(schema, tableConfig); } protected FluentQueryTest.DeclaringTable givenSingleNullableIntFieldTable(boolean nullHandling) { @@ -118,4 +148,30 @@ void destroyBaseDir() FileUtils.deleteDirectory(_baseDir); } } + + class DataTypeScenario { + private final FieldSpec.DataType _dataType; + + public DataTypeScenario(FieldSpec.DataType dataType) { + _dataType = dataType; + } + + public FieldSpec.DataType getDataType() { + return _dataType; + } + + public FluentQueryTest.DeclaringTable getDeclaringTable(boolean nullHandlingEnabled) { + return givenSingleNullableFieldTable(_dataType, nullHandlingEnabled); + } + + public FluentQueryTest.DeclaringTable getDeclaringTable(boolean nullHandlingEnabled, + FieldSpec.FieldType fieldType) { + return givenSingleNullableFieldTable(_dataType, nullHandlingEnabled, fieldType, null); + } + + @Override + public String toString() { + return "DataTypeScenario{" + "dt=" + _dataType + '}'; + } + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggFunctionTest.java new file mode 100644 index 00000000000..9cb85467cf8 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggFunctionTest.java @@ -0,0 +1,661 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.Test; + + +public class ArrayAggFunctionTest extends AbstractAggregationFunctionTest { + + @Test + void aggregationAllNullsWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'INT') from testTable") + .thenResultIs(new Object[]{new int[]{Integer.MIN_VALUE, Integer.MIN_VALUE, Integer.MIN_VALUE, Integer.MIN_VALUE, + Integer.MIN_VALUE}}); + } + + @Test + void aggregationAllNullsWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'LONG') from testTable") + .thenResultIs(new Object[]{new long[0]}); + } + + @Test + void aggregationGroupBySVAllNullsWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select 'literal', arrayagg(myField, 'FLOAT') from testTable group by 'literal'") + 
.thenResultIs(new Object[]{"literal", new float[]{Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY, + Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY}}); + } + + @Test + void aggregationGroupBySVAllNullsWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select 'literal', arrayagg(myField, 'DOUBLE') from testTable group by 'literal'") + .thenResultIs(new Object[]{"literal", new double[0]}); + } + + @Test + void aggregationIntWithNullHandlingDisabled() { + // Use repeated segment values because order of processing across segments isn't deterministic and not relevant + // to this test + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "1", + "null", + "2" + ).whenQuery("select arrayagg(myField, 'INT') from testTable") + .thenResultIs(new Object[]{new int[]{1, Integer.MIN_VALUE, 2, 1, Integer.MIN_VALUE, 2}}); + } + + @Test + void aggregationIntWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "1", + "null", + "2" + ).whenQuery("select arrayagg(myField, 'INT') from testTable") + .thenResultIs(new Object[]{new int[]{1, 2, 1, 2}}); + } + + @Test + void aggregationDistinctIntWithNullHandlingDisabled() { + // Use a single value in the segment because ordering is currently not deterministic due to the use of a hashset in + // distinct array agg + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'INT', true) from testTable") + .thenResultIs(new Object[]{new int[]{Integer.MIN_VALUE}}); + } + + @Test + void 
aggregationDistinctIntWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "null", + "1" + ).whenQuery("select arrayagg(myField, 'INT', true) from testTable") + .thenResultIs(new Object[]{new int[]{1}}); + } + + @Test + void aggregationGroupBySVIntWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'INT') from testTable group by myField") + .thenResultIs(new Object[]{1, new int[]{1, 1}}, new Object[]{2, new int[]{2, 2}}, + new Object[]{Integer.MIN_VALUE, new int[]{Integer.MIN_VALUE, Integer.MIN_VALUE}}); + } + + @Test + void aggregationGroupBySVIntWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'INT') from testTable group by myField") + .thenResultIs(new Object[]{1, new int[]{1, 1}}, new Object[]{2, new int[]{2, 2}}, + new Object[]{null, new int[0]}); + } + + @Test + void aggregationDistinctGroupBySVIntWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'INT', true) from testTable group by myField") + .thenResultIs(new Object[]{1, new int[]{1}}, new Object[]{2, new int[]{2}}, + new Object[]{Integer.MIN_VALUE, new int[]{Integer.MIN_VALUE}}); + } + + @Test + void aggregationDistinctGroupBySVIntWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.INT).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "2", + "null" + 
).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'INT', true) from testTable group by myField") + .thenResultIs(new Object[]{1, new int[]{1}}, new Object[]{2, new int[]{2}}, + new Object[]{null, new int[0]}); + } + + @Test + void aggregationLongWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "1", + "null", + "2" + ).whenQuery("select arrayagg(myField, 'LONG') from testTable") + .thenResultIs(new Object[]{new long[]{1, Long.MIN_VALUE, 2, 1, Long.MIN_VALUE, 2}}); + } + + @Test + void aggregationLongWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "1", + "null", + "2" + ).whenQuery("select arrayagg(myField, 'LONG') from testTable") + .thenResultIs(new Object[]{new long[]{1, 2, 1, 2}}); + } + + @Test + void aggregationDistinctLongWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'LONG', true) from testTable") + .thenResultIs(new Object[]{new long[]{Long.MIN_VALUE}}); + } + + @Test + void aggregationDistinctLongWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "null", + "1" + ).whenQuery("select arrayagg(myField, 'LONG', true) from testTable") + .thenResultIs(new Object[]{new long[]{1}}); + } + + @Test + void aggregationGroupBySVLongWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'LONG') from 
testTable group by myField") + .thenResultIs(new Object[]{1L, new long[]{1, 1}}, new Object[]{2L, new long[]{2, 2}}, + new Object[]{Long.MIN_VALUE, new long[]{Long.MIN_VALUE, Long.MIN_VALUE}}); + } + + @Test + void aggregationGroupBySVLongWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'LONG') from testTable group by myField") + .thenResultIs(new Object[]{1L, new long[]{1, 1}}, new Object[]{2L, new long[]{2, 2}}, + new Object[]{null, new long[0]}); + } + + @Test + void aggregationDistinctGroupBySVLongWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(false) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'LONG', true) from testTable group by myField") + .thenResultIs(new Object[]{1L, new long[]{1}}, new Object[]{2L, new long[]{2}}, + new Object[]{Long.MIN_VALUE, new long[]{Long.MIN_VALUE}}); + } + + @Test + void aggregationDistinctGroupBySVLongWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.LONG).getDeclaringTable(true) + .onFirstInstance("myField", + "1", + "2", + "null" + ).andOnSecondInstance("myField", + "1", + "2", + "null" + ).whenQuery("select myField, arrayagg(myField, 'LONG', true) from testTable group by myField") + .thenResultIs(new Object[]{1L, new long[]{1}}, new Object[]{2L, new long[]{2}}, + new Object[]{null, new long[0]}); + } + + @Test + void aggregationFloatWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(false) + .onFirstInstance("myField", + "1.0", + "null", + "2.0" + ).andOnSecondInstance("myField", + "1.0", + "null", + "2.0" + ).whenQuery("select arrayagg(myField, 'FLOAT') from testTable") + .thenResultIs(new 
Object[]{new float[]{1.0f, Float.NEGATIVE_INFINITY, 2.0f, 1.0f, Float.NEGATIVE_INFINITY, + 2.0f}}); + } + + @Test + void aggregationFloatWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(true) + .onFirstInstance("myField", + "1.0", + "null", + "2.0" + ).andOnSecondInstance("myField", + "1.0", + "null", + "2.0" + ).whenQuery("select arrayagg(myField, 'FLOAT') from testTable") + .thenResultIs(new Object[]{new float[]{1.0f, 2.0f, 1.0f, 2.0f}}); + } + + @Test + void aggregationDistinctFloatWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'FLOAT', true) from testTable") + .thenResultIs(new Object[]{new float[]{Float.NEGATIVE_INFINITY}}); + } + + @Test + void aggregationDistinctFloatWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(true) + .onFirstInstance("myField", + "1.0", + "null", + "1.0" + ).whenQuery("select arrayagg(myField, 'FLOAT', true) from testTable") + .thenResultIs(new Object[]{new float[]{1.0f}}); + } + + @Test + void aggregationGroupBySVFloatWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + "null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'FLOAT') from testTable group by myField") + .thenResultIs(new Object[]{Float.NEGATIVE_INFINITY, + new float[]{Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY}}, + new Object[]{1.0f, new float[]{1.0f, 1.0f}}, new Object[]{2.0f, new float[]{2.0f, 2.0f}}); + } + + @Test + void aggregationGroupBySVFloatWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1.0" + ).andOnSecondInstance("myField", + "null", + "1.0" + ).whenQuery("select 
myField, arrayagg(myField, 'FLOAT') from testTable group by myField") + .thenResultIs(new Object[]{null, new float[0]}, new Object[]{1.0f, new float[]{1.0f, 1.0f}}); + } + + @Test + void aggregationDistinctGroupBySVFloatWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + "null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'FLOAT', true) from testTable group by myField") + .thenResultIs(new Object[]{Float.NEGATIVE_INFINITY, new float[]{Float.NEGATIVE_INFINITY}}, + new Object[]{1.0f, new float[]{1.0f}}, new Object[]{2.0f, new float[]{2.0f}}); + } + + @Test + void aggregationDistinctGroupBySVFloatWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.FLOAT).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1.0" + ).andOnSecondInstance("myField", + "null", + "1.0" + ).whenQuery("select myField, arrayagg(myField, 'FLOAT', true) from testTable group by myField") + .thenResultIs(new Object[]{null, new float[0]}, new Object[]{1.0f, new float[]{1.0f}}); + } + + @Test + void aggregationDoubleWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(false) + .onFirstInstance("myField", + "1.0", + "null", + "2.0" + ).andOnSecondInstance("myField", + "1.0", + "null", + "2.0" + ).whenQuery("select arrayagg(myField, 'DOUBLE') from testTable") + .thenResultIs(new Object[]{new double[]{1.0, Double.NEGATIVE_INFINITY, 2.0, 1.0, Double.NEGATIVE_INFINITY, + 2.0}}); + } + + @Test + void aggregationDoubleWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(true) + .onFirstInstance("myField", + "1.0", + "null", + "2.0" + ).andOnSecondInstance("myField", + "1.0", + "null", + "2.0" + ).whenQuery("select arrayagg(myField, 'DOUBLE') from testTable") + .thenResultIs(new Object[]{new double[]{1.0, 2.0, 1.0, 2.0}}); + } 
+ + @Test + void aggregationDistinctDoubleWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'DOUBLE', true) from testTable") + .thenResultIs(new Object[]{new double[]{Double.NEGATIVE_INFINITY}}); + } + + @Test + void aggregationDistinctDoubleWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(true) + .onFirstInstance("myField", + "1.0", + "null", + "1.0" + ).whenQuery("select arrayagg(myField, 'DOUBLE', true) from testTable") + .thenResultIs(new Object[]{new double[]{1.0}}); + } + + @Test + void aggregationGroupBySVDoubleWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + "null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'DOUBLE') from testTable group by myField") + .thenResultIs(new Object[]{Double.NEGATIVE_INFINITY, new double[]{Double.NEGATIVE_INFINITY, + Double.NEGATIVE_INFINITY}}, new Object[]{1.0, new double[]{1.0, 1.0}}, + new Object[]{2.0, new double[]{2.0, 2.0}}); + } + + @Test + void aggregationGroupBySVDoubleWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + "null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'DOUBLE') from testTable group by myField") + .thenResultIs(new Object[]{null, new double[0]}, new Object[]{1.0, new double[]{1.0, 1.0}}, + new Object[]{2.0, new double[]{2.0, 2.0}}); + } + + @Test + void aggregationDistinctGroupBySVDoubleWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + 
"null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'DOUBLE', true) from testTable group by myField") + .thenResultIs(new Object[]{Double.NEGATIVE_INFINITY, new double[]{Double.NEGATIVE_INFINITY}}, + new Object[]{1.0, new double[]{1.0}}, new Object[]{2.0, new double[]{2.0}}); + } + + @Test + void aggregationDistinctGroupBySVDoubleWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.DOUBLE).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1.0", + "2.0" + ).andOnSecondInstance("myField", + "null", + "1.0", + "2.0" + ).whenQuery("select myField, arrayagg(myField, 'DOUBLE', true) from testTable group by myField") + .thenResultIs(new Object[]{null, new double[0]}, new Object[]{1.0, new double[]{1.0}}, + new Object[]{2.0, new double[]{2.0}}); + } + + @Test + void aggregationStringWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(false) + .onFirstInstance("myField", + "a", + "null", + "b" + ).andOnSecondInstance("myField", + "a", + "null", + "b" + ).whenQuery("select arrayagg(myField, 'STRING') from testTable") + .thenResultIs(new Object[]{new String[]{"a", "null", "b", "a", "null", "b"}}); + } + + @Test + void aggregationStringWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(true) + .onFirstInstance("myField", + "a", + "null", + "b" + ).andOnSecondInstance("myField", + "a", + "null", + "b" + ).whenQuery("select arrayagg(myField, 'STRING') from testTable") + .thenResultIs(new Object[]{new String[]{"a", "b", "a", "b"}}); + } + + @Test + void aggregationDistinctStringWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).whenQuery("select arrayagg(myField, 'STRING', true) from testTable") + .thenResultIs(new Object[]{new String[]{"null"}}); + } + + @Test + void aggregationDistinctStringWithNullHandlingEnabled() { + new 
DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(true) + .onFirstInstance("myField", + "a", + "null", + "a" + ).whenQuery("select arrayagg(myField, 'STRING', true) from testTable") + .thenResultIs(new Object[]{new String[]{"a"}}); + } + + @Test + void aggregationGroupBySVStringWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(false) + .onFirstInstance("myField", + "a", + "b", + "null" + ).andOnSecondInstance("myField", + "a", + "b", + "null" + ).whenQuery("select myField, arrayagg(myField, 'STRING') from testTable group by myField") + .thenResultIs(new Object[]{"a", new String[]{"a", "a"}}, new Object[]{"b", new String[]{"b", "b"}}, + new Object[]{"null", new String[]{"null", "null"}}); + } + + @Test + void aggregationGroupBySVStringWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(true) + .onFirstInstance("myField", + "a", + "b", + "null" + ).andOnSecondInstance("myField", + "a", + "b", + "null" + ).whenQuery("select myField, arrayagg(myField, 'STRING') from testTable group by myField") + .thenResultIs(new Object[]{"a", new String[]{"a", "a"}}, new Object[]{"b", new String[]{"b", "b"}}, + new Object[]{null, new String[0]}); + } + + @Test + void aggregationDistinctGroupBySVStringWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(false) + .onFirstInstance("myField", + "a", + "b", + "null" + ).andOnSecondInstance("myField", + "a", + "b", + "null" + ).whenQuery("select myField, arrayagg(myField, 'STRING', true) from testTable group by myField") + .thenResultIs(new Object[]{"a", new String[]{"a"}}, new Object[]{"b", new String[]{"b"}}, + new Object[]{"null", new String[]{"null"}}); + } + + @Test + void aggregationDistinctGroupBySVStringWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.STRING).getDeclaringTable(true) + .onFirstInstance("myField", + "a", + "b", + "null" + 
).andOnSecondInstance("myField", + "a", + "b", + "null" + ).whenQuery("select myField, arrayagg(myField, 'STRING', true) from testTable group by myField") + .thenResultIs(new Object[]{"a", new String[]{"a"}}, new Object[]{"b", new String[]{"b"}}, + new Object[]{null, new String[0]}); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunctionTest.java new file mode 100644 index 00000000000..e52eb318004 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunctionTest.java @@ -0,0 +1,144 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class AvgAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.BIG_DECIMAL) + }; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select avg(myField) from testTable") + .thenResultIs("DOUBLE", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select avg(myField) from testTable") + .thenResultIs("DOUBLE", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', avg(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), null)); + } + + @Test(dataProvider = "scenarios") + void 
aggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', avg(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "7", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "3" + ).whenQuery("select avg(myField) from testTable") + .thenResultIs("DOUBLE", "3"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "7", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "3" + ).whenQuery("select avg(myField) from testTable") + .thenResultIs("DOUBLE", "5"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "7", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "3" + ).whenQuery("select 'literal', avg(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 3"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "7", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "3" + ).whenQuery("select 'literal', avg(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 5"); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/BooleanAggregationFunctionTest.java 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/BooleanAggregationFunctionTest.java new file mode 100644 index 00000000000..2fda45469cb --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/BooleanAggregationFunctionTest.java @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.Test; + + +public class BooleanAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @Test + void aggregationAllNullsWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select bool_or(myField) from testTable") + .thenResultIs("BOOLEAN", "false"); + } + + @Test + void aggregationAllNullsWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select bool_and(myField) from testTable") + .thenResultIs("BOOLEAN", "null"); + } + + @Test + void aggregationGroupBySVAllNullsWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', bool_and(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | false"); + } + + @Test + void aggregationGroupBySVAllNullsWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', bool_or(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | null"); + } + + @Test + void andAggregationWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "true" + ).andOnSecondInstance("myField", + "null", + "true" + ).whenQuery("select bool_and(myField) from testTable") + .thenResultIs("BOOLEAN", 
"false"); + } + + @Test + void andAggregationWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "true" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select bool_and(myField) from testTable") + .thenResultIs("BOOLEAN", "true"); + } + + @Test + void andAggregationGroupBySVWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "true" + ).andOnSecondInstance("myField", + "null", + "true" + ).whenQuery("select 'literal', bool_and(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | false"); + } + + @Test + void andAggregationGroupBySVWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "true" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select 'literal', bool_and(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | true"); + } + + @Test + void orAggregationWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "false" + ).andOnSecondInstance("myField", + "null", + "false" + ).whenQuery("select bool_or(myField) from testTable") + .thenResultIs("BOOLEAN", "false"); + } + + @Test + void orAggregationWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "true" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select bool_or(myField) from testTable") + .thenResultIs("BOOLEAN", "true"); + } + + @Test + void orAggregationGroupBySVWithNullHandlingDisabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "true" + 
).andOnSecondInstance("myField", + "null", + "true" + ).whenQuery("select 'literal', bool_or(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | true"); + } + + @Test + void orAggregationGroupBySVWithNullHandlingEnabled() { + new DataTypeScenario(FieldSpec.DataType.BOOLEAN).getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "false" + ).andOnSecondInstance("myField", + "null", + "false" + ).whenQuery("select 'literal', bool_or(myField) from testTable group by 'literal'") + .thenResultIs("STRING | BOOLEAN", "literal | false"); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java index 5f60e756a1d..7a0d5f0da68 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java @@ -21,6 +21,7 @@ import org.apache.pinot.queries.FluentQueryTest; import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -30,7 +31,7 @@ public class CountAggregationFunctionTest extends AbstractAggregationFunctionTes public void list() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(false) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( new Object[] {1} ) @@ -50,7 +51,7 @@ public void list() { public void listNullHandlingEnabled() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(true) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + 
.givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( new Object[] {1} ) @@ -70,7 +71,7 @@ public void listNullHandlingEnabled() { public void countNullWhenHandlingDisabled() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(false) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( "myField", "1" @@ -93,7 +94,7 @@ public void countNullWhenHandlingDisabled() { public void countNullWhenHandlingEnabled() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(true) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( "myField", "1" @@ -116,7 +117,7 @@ public void countNullWhenHandlingEnabled() { public void countStarNullWhenHandlingDisabled() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(false) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( "myField", "1" @@ -138,7 +139,7 @@ public void countStarNullWhenHandlingDisabled() { public void countStarNullWhenHandlingEnabled() { FluentQueryTest.withBaseDir(_baseDir) .withNullHandling(true) - .givenTable(SINGLE_FIELD_NULLABLE_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) .onFirstInstance( "myField", "1" @@ -153,6 +154,55 @@ public void countStarNullWhenHandlingEnabled() { "1 | 1", "2 | 1", "null | 1" - );; + ); + } + + @Test(dataProvider = "nullHandlingEnabled") + public void 
countStarWithoutGroupBy(boolean nullHandlingEnabled) { + FluentQueryTest.withBaseDir(_baseDir) + .withNullHandling(nullHandlingEnabled) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .onFirstInstance( + "myField", + "1", + "2", + "null" + ) + .andOnSecondInstance( + "myField", + "null", + "null" + ) + .whenQuery("select COUNT(*) from testTable") + // COUNT(*) result should be the same regardless of whether null handling is enabled or not + .thenResultIs("LONG", "5"); + } + + @Test(dataProvider = "nullHandlingEnabled") + public void countLiteralWithoutGroupBy(boolean nullHandlingEnabled) { + FluentQueryTest.withBaseDir(_baseDir) + .withNullHandling(nullHandlingEnabled) + .givenTable(SINGLE_FIELD_NULLABLE_DIMENSION_SCHEMAS.get(FieldSpec.DataType.INT), SINGLE_FIELD_TABLE_CONFIG) + .onFirstInstance( + "myField", + "1", + "2", + "null" + ) + .andOnSecondInstance( + "myField", + "null", + "null" + ) + .whenQuery("select COUNT('literal') from testTable") + // COUNT('literal') result should be the same regardless of whether null handling is enabled or not + .thenResultIs("LONG", "5"); + } + + @DataProvider(name = "nullHandlingEnabled") + public Object[][] nullHandlingEnabled() { + return new Object[][]{ + {false}, {true} + }; } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunctionTest.java new file mode 100644 index 00000000000..2067a4c9147 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunctionTest.java @@ -0,0 +1,281 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class DistinctAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE) + }; + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select count(distinct myField) from testTable") + .thenResultIs("INTEGER", "1"); + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select count(distinct myField) from testTable") + .thenResultIs("INTEGER", "0"); + } + + @Test(dataProvider = "scenarios") + void 
distinctCountAggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', count(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | INTEGER", "literal | 1"); + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null", + "null" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', count(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | INTEGER", "literal | 0"); + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1", + "2", + "2" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select count(distinct myField) from testTable") + .thenResultIs("INTEGER", "3"); + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1", + "2", + "2" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select count(distinct myField) from testTable") + .thenResultIs("INTEGER", "2"); + } + + @Test(dataProvider = "scenarios") + void distinctCountAggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "3", + "2" + ).whenQuery("select 'literal', count(distinct myField) from testTable group by 'literal'") + 
.thenResultIs("STRING | INTEGER", "literal | 3"); + } + + @Test(dataProvider = "scenarios") + void distinctSumAggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1", + "2", + "2" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select sum(distinct myField) from testTable") + .thenResultIs("DOUBLE", addToDefaultNullValue(scenario.getDataType(), 3)); + } + + @Test(dataProvider = "scenarios") + void distinctSumAggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1", + "2", + "2" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select sum(distinct myField) from testTable") + .thenResultIs("DOUBLE", "3"); + } + + @Test(dataProvider = "scenarios") + void distinctSumAggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "3", + "2" + ).whenQuery("select 'literal', sum(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + addToDefaultNullValue(scenario.getDataType(), 6)); + } + + @Test(dataProvider = "scenarios") + void distinctSumAggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "3", + "2" + ).whenQuery("select 'literal', sum(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 6"); + } + + @Test(dataProvider = "scenarios") + void distinctAvgAggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "1", + "2" + 
).andOnSecondInstance("myField", + "2", + "null", + "null" + ).whenQuery("select avg(distinct myField) from testTable") + .thenResultIs("DOUBLE", "1.0"); + } + + @Test(dataProvider = "scenarios") + void distinctAvgAggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "null", + "null" + ).whenQuery("select avg(distinct myField) from testTable") + .thenResultIs("DOUBLE", "1.5"); + } + + @Test(dataProvider = "scenarios") + void distinctAvgAggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "3", + "2" + ).whenQuery("select 'literal', avg(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 1.5"); + } + + @Test(dataProvider = "scenarios") + void distinctAvgAggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "1", + "2" + ).andOnSecondInstance("myField", + "2", + "3", + "2" + ).whenQuery("select 'literal', avg(distinct myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 2.0"); + } + + private String addToDefaultNullValue(FieldSpec.DataType dataType, int addend) { + switch (dataType) { + case INT: + return String.valueOf(FieldSpec.DEFAULT_DIMENSION_NULL_VALUE_OF_INT + addend); + case LONG: + return String.valueOf(FieldSpec.DEFAULT_DIMENSION_NULL_VALUE_OF_LONG + addend); + case FLOAT: + return String.valueOf(FieldSpec.DEFAULT_DIMENSION_NULL_VALUE_OF_FLOAT + addend); + case DOUBLE: + return String.valueOf(FieldSpec.DEFAULT_DIMENSION_NULL_VALUE_OF_DOUBLE + addend); + default: + throw new 
IllegalArgumentException("Unsupported data type: " + dataType); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunctionTest.java new file mode 100644 index 00000000000..319f54e102d --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunctionTest.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class MaxAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.BIG_DECIMAL) + }; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select max(myField) from testTable") + .thenResultIs("DOUBLE", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select max(myField) from testTable") + .thenResultIs("DOUBLE", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', max(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null)); + } + + @Test(dataProvider = "scenarios") + void 
aggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', max(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "2", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "5", + "null" + ).whenQuery("select max(myField) from testTable") + .thenResultIs("DOUBLE", "5"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "2", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "5", + "null" + ).whenQuery("select max(myField) from testTable") + .thenResultIs("DOUBLE", "5"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "2", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "5", + "null" + ).whenQuery("select 'literal', max(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 5"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "3", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', max(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 5"); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunctionTest.java 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunctionTest.java new file mode 100644 index 00000000000..79d3312fe1c --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunctionTest.java @@ -0,0 +1,150 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class MinAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.BIG_DECIMAL) + }; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select min(myField) from testTable") + .thenResultIs("DOUBLE", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select min(myField) from testTable") + .thenResultIs("DOUBLE", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', min(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null)); + } + + @Test(dataProvider = "scenarios") + void 
aggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', min(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "2", + "null" + ).whenQuery("select min(myField) from testTable") + .thenResultIs("DOUBLE", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "2", + "null" + ).whenQuery("select min(myField) from testTable") + .thenResultIs("DOUBLE", "2"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "2", + "null" + ).whenQuery("select 'literal', min(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.DIMENSION, scenario.getDataType(), null)); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', min(myField) from testTable group by 'literal'") 
+ .thenResultIs("STRING | DOUBLE", "literal | 3"); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunctionTest.java index 822399d66fb..bc019f630cf 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunctionTest.java @@ -20,7 +20,6 @@ package org.apache.pinot.core.query.aggregation.function; import org.apache.pinot.common.utils.PinotDataType; -import org.apache.pinot.queries.FluentQueryTest; import org.apache.pinot.spi.data.FieldSpec; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -31,30 +30,13 @@ public class MinMaxRangeAggregationFunctionTest extends AbstractAggregationFunct @DataProvider(name = "scenarios") Object[] scenarios() { return new Object[] { - new Scenario(FieldSpec.DataType.INT), - new Scenario(FieldSpec.DataType.LONG), - new Scenario(FieldSpec.DataType.FLOAT), - new Scenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), }; } - public class Scenario { - private final FieldSpec.DataType _dataType; - - public Scenario(FieldSpec.DataType dataType) { - _dataType = dataType; - } - - public FluentQueryTest.DeclaringTable getDeclaringTable(boolean nullHandlingEnabled) { - return givenSingleNullableFieldTable(_dataType, nullHandlingEnabled); - } - - @Override - public String toString() { - return "Scenario{" + "dt=" + _dataType + '}'; - } - } - String diffBetweenMinAnd9(FieldSpec.DataType dt) { switch (dt) { case INT: return "2.147483657E9"; @@ -66,7 +48,7 @@ String 
diffBetweenMinAnd9(FieldSpec.DataType dt) { } @Test(dataProvider = "scenarios") - void aggrWithoutNull(Scenario scenario) { + void aggrWithoutNull(DataTypeScenario scenario) { scenario.getDeclaringTable(false) .onFirstInstance("myField", "null", @@ -78,11 +60,11 @@ void aggrWithoutNull(Scenario scenario) { "null" ) .whenQuery("select minmaxrange(myField) from testTable") - .thenResultIs("DOUBLE", diffBetweenMinAnd9(scenario._dataType)); + .thenResultIs("DOUBLE", diffBetweenMinAnd9(scenario.getDataType())); } @Test(dataProvider = "scenarios") - void aggrWithNull(Scenario scenario) { + void aggrWithNull(DataTypeScenario scenario) { scenario.getDeclaringTable(true) .onFirstInstance("myField", "null", @@ -97,7 +79,7 @@ void aggrWithNull(Scenario scenario) { } @Test(dataProvider = "scenarios") - void aggrSvWithoutNull(Scenario scenario) { + void aggrSvWithoutNull(DataTypeScenario scenario) { scenario.getDeclaringTable(false) .onFirstInstance("myField", "null", @@ -108,11 +90,11 @@ void aggrSvWithoutNull(Scenario scenario) { "9", "null" ).whenQuery("select 'cte', minmaxrange(myField) from testTable group by 'cte'") - .thenResultIs("STRING | DOUBLE", "cte | " + diffBetweenMinAnd9(scenario._dataType)); + .thenResultIs("STRING | DOUBLE", "cte | " + diffBetweenMinAnd9(scenario.getDataType())); } @Test(dataProvider = "scenarios") - void aggrSvWithNull(Scenario scenario) { + void aggrSvWithNull(DataTypeScenario scenario) { scenario.getDeclaringTable(true) .onFirstInstance("myField", "null", @@ -137,12 +119,12 @@ String aggrSvSelfWithoutNullResult(FieldSpec.DataType dt) { } @Test(dataProvider = "scenarios") - void aggrSvSelfWithoutNull(Scenario scenario) { - PinotDataType pinotDataType = scenario._dataType == FieldSpec.DataType.INT - ? PinotDataType.INTEGER : PinotDataType.valueOf(scenario._dataType.name()); + void aggrSvSelfWithoutNull(DataTypeScenario scenario) { + PinotDataType pinotDataType = scenario.getDataType() == FieldSpec.DataType.INT + ? 
PinotDataType.INTEGER : PinotDataType.valueOf(scenario.getDataType().name()); Object defaultNullValue; - switch (scenario._dataType) { + switch (scenario.getDataType()) { case INT: defaultNullValue = Integer.MIN_VALUE; break; @@ -156,7 +138,7 @@ void aggrSvSelfWithoutNull(Scenario scenario) { defaultNullValue = Double.NEGATIVE_INFINITY; break; default: - throw new IllegalArgumentException("Unexpected scenario data type " + scenario._dataType); + throw new IllegalArgumentException("Unexpected scenario data type " + scenario.getDataType()); } scenario.getDeclaringTable(false) @@ -170,15 +152,15 @@ void aggrSvSelfWithoutNull(Scenario scenario) { "2" ).whenQuery("select myField, minmaxrange(myField) from testTable group by myField order by myField") .thenResultIs(pinotDataType + " | DOUBLE", - defaultNullValue + " | " + aggrSvSelfWithoutNullResult(scenario._dataType), + defaultNullValue + " | " + aggrSvSelfWithoutNullResult(scenario.getDataType()), "1 | 0", "2 | 0"); } @Test(dataProvider = "scenarios") - void aggrSvSelfWithNull(Scenario scenario) { - PinotDataType pinotDataType = scenario._dataType == FieldSpec.DataType.INT - ? PinotDataType.INTEGER : PinotDataType.valueOf(scenario._dataType.name()); + void aggrSvSelfWithNull(DataTypeScenario scenario) { + PinotDataType pinotDataType = scenario.getDataType() == FieldSpec.DataType.INT + ? 
PinotDataType.INTEGER : PinotDataType.valueOf(scenario.getDataType().name()); scenario.getDeclaringTable(true) .onFirstInstance("myField", diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunctionTest.java new file mode 100644 index 00000000000..1bc333f0fb4 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunctionTest.java @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class SumAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.BIG_DECIMAL) + }; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select sum(myField) from testTable") + .thenResultIs("DOUBLE", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select sum(myField) from testTable") + .thenResultIs("DOUBLE", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', sum(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), null)); + } + + @Test(dataProvider = 
"scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', sum(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "3", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select sum(myField) from testTable") + .thenResultIs("DOUBLE", "8"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "5", + "null" + ).andOnSecondInstance("myField", + "2", + "null", + "3" + ).whenQuery("select sum(myField) from testTable") + .thenResultIs("DOUBLE", "10"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "2", + "null" + ).whenQuery("select 'literal', sum(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 10"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', sum(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal 
| 8"); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunctionTest.java new file mode 100644 index 00000000000..76de0d07e28 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/SumPrecisionAggregationFunctionTest.java @@ -0,0 +1,155 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public class SumPrecisionAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new DataTypeScenario(FieldSpec.DataType.INT), + new DataTypeScenario(FieldSpec.DataType.LONG), + new DataTypeScenario(FieldSpec.DataType.FLOAT), + new DataTypeScenario(FieldSpec.DataType.DOUBLE), + new DataTypeScenario(FieldSpec.DataType.BIG_DECIMAL) + }; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select sumprecision(myField) from testTable") + .thenResultIs("STRING", + String.valueOf(FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), null))); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select sumprecision(myField) from testTable") + .thenResultIs("STRING", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', sumprecision(myField) from testTable group by 'literal'") + .thenResultIs("STRING | STRING", "literal | " + + FieldSpec.getDefaultNullValue(FieldSpec.FieldType.METRIC, scenario.getDataType(), 
null)); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', sumprecision(myField) from testTable group by 'literal'") + .thenResultIs("STRING | STRING", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "3", + "null", + "5" + ).andOnSecondInstance("myField", + "null", + "null" + ).whenQuery("select sumprecision(myField) from testTable") + .thenResultIs("STRING", getStringValueOfSum(8, scenario.getDataType())); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "5", + "null" + ).andOnSecondInstance("myField", + "2", + "null", + "3" + ).whenQuery("select sumprecision(myField) from testTable") + .thenResultIs("STRING", getStringValueOfSum(10, scenario.getDataType())); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "2", + "null" + ).whenQuery("select 'literal', sumprecision(myField) from testTable group by 'literal'") + .thenResultIs("STRING | STRING", "literal | " + getStringValueOfSum(10, scenario.getDataType())); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + 
"5", + "null", + "3" + ).andOnSecondInstance("myField", + "null", + "null", + "null" + ).whenQuery("select 'literal', sumprecision(myField) from testTable group by 'literal'") + .thenResultIs("STRING | STRING", "literal | " + getStringValueOfSum(8, scenario.getDataType())); + } + + private String getStringValueOfSum(int sum, FieldSpec.DataType dataType) { + if (dataType == FieldSpec.DataType.FLOAT || dataType == FieldSpec.DataType.DOUBLE) { + return String.valueOf((double) sum); + } else { + return String.valueOf(sum); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunctionTest.java new file mode 100644 index 00000000000..0547d8e5aa2 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunctionTest.java @@ -0,0 +1,168 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.query.aggregation.function; + +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils.calculateVariance; + + +public class VarianceAggregationFunctionTest extends AbstractAggregationFunctionTest { + + private static final EnumSet VARIANCE_FUNCTIONS = EnumSet.of(AggregationFunctionType.VARPOP, + AggregationFunctionType.VARSAMP, AggregationFunctionType.STDDEVPOP, AggregationFunctionType.STDDEVSAMP); + + private static final Set DATA_TYPES = Set.of(FieldSpec.DataType.INT, FieldSpec.DataType.LONG, + FieldSpec.DataType.FLOAT, FieldSpec.DataType.DOUBLE); + + @DataProvider(name = "scenarios") + Object[][] scenarios() { + Object[][] scenarios = new Object[16][2]; + + int i = 0; + for (AggregationFunctionType functionType : VARIANCE_FUNCTIONS) { + for (FieldSpec.DataType dataType : DATA_TYPES) { + scenarios[i][0] = functionType; + scenarios[i][1] = new DataTypeScenario(dataType); + i++; + } + } + + return scenarios; + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingDisabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select " + functionType.getName() + "(myField) from testTable") + .thenResultIs("DOUBLE", "0.0"); + } + + @Test(dataProvider = "scenarios") + void aggregationAllNullsWithNullHandlingEnabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", 
+ "null" + ).whenQuery("select " + functionType.getName() + "(myField) from testTable") + .thenResultIs("DOUBLE", "null"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingDisabled(AggregationFunctionType functionType, + DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', " + functionType.getName() + "(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | 0.0"); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVAllNullsWithNullHandlingEnabled(AggregationFunctionType functionType, + DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "null", + "null" + ).andOnSecondInstance("myField", + "null" + ).whenQuery("select 'literal', " + functionType.getName() + "(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | null"); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingDisabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "3", + "6", + "null" + ).whenQuery("select " + functionType.getName() + "(myField) from testTable") + .thenResultIs("DOUBLE", String.valueOf(calculateVariance(List.of(1.0, 0.0, 2.0, 3.0, 6.0, 0.0), + functionType))); + } + + @Test(dataProvider = "scenarios") + void aggregationWithNullHandlingEnabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "3", + "6", + "null" + ).whenQuery("select " + functionType.getName() 
+ "(myField) from testTable") + .thenResultIs("DOUBLE", String.valueOf(calculateVariance(List.of(1.0, 2.0, 3.0, 6.0), functionType))); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingDisabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(false, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "3", + "6", + "null" + ).whenQuery("select 'literal', " + functionType.getName() + "(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + calculateVariance(List.of(1.0, 0.0, 2.0, 3.0, 6.0, 0.0), functionType)); + } + + @Test(dataProvider = "scenarios") + void aggregationGroupBySVWithNullHandlingEnabled(AggregationFunctionType functionType, DataTypeScenario scenario) { + scenario.getDeclaringTable(true, FieldSpec.FieldType.METRIC) + .onFirstInstance("myField", + "1", + "null", + "2" + ).andOnSecondInstance("myField", + "3", + "6", + "null" + ).whenQuery("select 'literal', " + functionType.getName() + "(myField) from testTable group by 'literal'") + .thenResultIs("STRING | DOUBLE", "literal | " + + calculateVariance(List.of(1.0, 2.0, 3.0, 6.0), functionType)); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java index a4710740184..af7b88753dd 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/BaseQueriesTest.java @@ -77,6 +77,10 @@ public abstract class BaseQueriesTest { protected static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool(2); protected static final BrokerMetrics BROKER_METRICS = mock(BrokerMetrics.class); + public final void shutdownExecutor() { + EXECUTOR_SERVICE.shutdownNow(); + } + @Language(value = "sql", prefix = "select * from table") protected 
abstract String getFilter(); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java index f0dae6a2f2e..f3982e65e51 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java @@ -102,6 +102,10 @@ public static class DeclaringTable { _extraQueryOptions = extraQueryOptions; } + public OnFirstInstance getFirstInstance() { + return new OnFirstInstance(_tableConfig, _schema, _baseDir, false, _baseQueriesTest, _extraQueryOptions); + } + public OnFirstInstance onFirstInstance(String... content) { return new OnFirstInstance(_tableConfig, _schema, _baseDir, false, _baseQueriesTest, _extraQueryOptions) .andSegment(content); @@ -220,6 +224,13 @@ public OnFirstInstance andSegment(String... tableText) { return this; } + public OnSecondInstance getSecondInstance() { + processSegments(); + return new OnSecondInstance( + _tableConfig, _schema, _indexDir.getParentFile(), !_onSecondInstance, _baseQueriesTest, _extraQueryOptions + ); + } + public OnSecondInstance andOnSecondInstance(Object[]... content) { processSegments(); return new OnSecondInstance( @@ -233,6 +244,15 @@ public OnSecondInstance andOnSecondInstance(String... content) { _tableConfig, _schema, _indexDir.getParentFile(), !_onSecondInstance, _baseQueriesTest, _extraQueryOptions) .andSegment(content); } + + public OnFirstInstance prepareToQuery() { + processSegments(); + return this; + } + + public void tearDown() { + _baseQueriesTest.shutdownExecutor(); + } } public static class OnSecondInstance extends TableWithSegments { @@ -250,6 +270,15 @@ public OnSecondInstance andSegment(String... 
tableText) { super.andSegment(tableText); return this; } + + public OnSecondInstance prepareToQuery() { + processSegments(); + return this; + } + + public void tearDown() { + _baseQueriesTest.shutdownExecutor(); + } } public static class QueryExecuted { diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java index 0214a82c2f0..798d6ad8e34 100644 --- a/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/SyntheticBlockValSets.java @@ -240,7 +240,7 @@ public RoaringBitmap getNullBitmap() { @Override public FieldSpec.DataType getValueType() { - return FieldSpec.DataType.LONG; + return FieldSpec.DataType.DOUBLE; } @Override diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationFunctionBenchmark.java similarity index 91% rename from pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java rename to pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationFunctionBenchmark.java index e3909278422..aa7354e33e2 100644 --- a/pinot-perf/src/main/java/org/apache/pinot/perf/AbstractAggregationFunctionBenchmark.java +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationFunctionBenchmark.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.perf; +package org.apache.pinot.perf.aggregation; import com.google.common.base.Preconditions; import java.util.Map; @@ -30,7 +30,9 @@ import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.infra.Blackhole; - +/** + * Base class for aggregation function benchmarks. 
+ */ public abstract class AbstractAggregationFunctionBenchmark { /** @@ -69,12 +71,25 @@ public abstract class AbstractAggregationFunctionBenchmark { */ protected abstract Map getBlockValSetMap(); + /** + * Returns the comparable final result extracted from the aggregation result holder. + *

    + * This method will be called in the benchmark method, so it must be fast. + */ + protected Comparable extractFinalResult(AggregationResultHolder aggregationResultHolder) { + return getAggregationFunction().extractFinalResult(aggregationResultHolder.getResult()); + } + /** * Verifies the final result of the aggregation function. * * This method will be called in the benchmark method, so it must be fast. */ protected void verifyResult(Blackhole bh, Comparable finalResult, Object expectedResult) { + if (expectedResult == null) { + Preconditions.checkArgument(finalResult == null, "Expected final result to be null, actual: %s", finalResult); + return; + } Preconditions.checkState(finalResult.equals(expectedResult), "Result mismatch: expected: %s, actual: %s", expectedResult, finalResult); bh.consume(finalResult); @@ -211,7 +226,7 @@ public void test(Blackhole bh) { getAggregationFunction().aggregate(DocIdSetPlanNode.MAX_DOC_PER_CALL, resultHolder, blockValSetMap); - Comparable finalResult = getAggregationFunction().extractFinalResult(resultHolder.getResult()); + Comparable finalResult = extractFinalResult(resultHolder); verifyResult(bh, finalResult, getExpectedResult()); } diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationQueryBenchmark.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationQueryBenchmark.java new file mode 100644 index 00000000000..7e8a3ded162 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/AbstractAggregationQueryBenchmark.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.perf.aggregation; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.List; +import org.apache.commons.io.FileUtils; +import org.apache.pinot.queries.FluentQueryTest; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.data.Schema; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.infra.Blackhole; + +/** + * Base class for aggregation query benchmarks. + */ +public abstract class AbstractAggregationQueryBenchmark { + + private File _baseDir; + private FluentQueryTest.OnSecondInstance _onSecondInstance; + + protected void init(boolean nullHandlingEnabled) throws IOException { + _baseDir = Files.createTempDirectory(getClass().getSimpleName()).toFile(); + TableConfig tableConfig = createTableConfig(); + Schema schema = createSchema(); + List> segmentsPerServer = createSegmentsPerServer(); + + FluentQueryTest.OnFirstInstance onFirstInstance = + FluentQueryTest.withBaseDir(_baseDir) + .withNullHandling(nullHandlingEnabled) + .givenTable(schema, tableConfig) + .getFirstInstance(); + + List segmentsOnFirstServer = segmentsPerServer.get(0); + for (Object[][] segment : segmentsOnFirstServer) { + onFirstInstance.andSegment(segment); + } + + FluentQueryTest.OnSecondInstance onSecondInstance = onFirstInstance.getSecondInstance(); + List segmentsOnSecondServer = segmentsPerServer.get(1); + for (Object[][] segment : segmentsOnSecondServer) { + onSecondInstance.andSegment(segment); + } + 
onSecondInstance.prepareToQuery(); + _onSecondInstance = onSecondInstance; + } + + @TearDown(Level.Trial) + public void tearDown() throws IOException { + if (_baseDir != null) { + FileUtils.deleteDirectory(_baseDir); + } + _onSecondInstance.tearDown(); + } + + protected void executeQuery(String query, Blackhole bh) { + bh.consume(_onSecondInstance.whenQuery(query)); + } + + protected abstract Schema createSchema(); + + protected abstract TableConfig createTableConfig(); + + /** + * Returns a list of segments to be created on the servers. The first list is the list of segments to be + * created on the first server and the second list is the segments to be created on the second server. + */ + protected abstract List> createSegmentsPerServer(); +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkAvgAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkAvgAggregation.java new file mode 100644 index 00000000000..78e92acaeb6 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkAvgAggregation.java @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.perf.aggregation; + +import java.util.Collections; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.AvgAggregationFunction; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.roaringbitmap.RoaringBitmap; + + +@Fork(1) +@BenchmarkMode(Mode.Throughput) +@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class BenchmarkAvgAggregation extends AbstractAggregationFunctionBenchmark.Stable { + private static final ExpressionContext EXPR = ExpressionContext.forIdentifier("col"); + + @Param({"false", "true"}) + private boolean _nullHandlingEnabled; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + 
Options opt = new OptionsBuilder().include(BenchmarkAvgAggregation.class.getSimpleName()).build(); + new Runner(opt).run(); + } + + @Override + protected AggregationFunction<?, ?> createAggregationFunction() { + return new AvgAggregationFunction(Collections.singletonList(EXPR), _nullHandlingEnabled); + } + + @Override + protected AggregationResultHolder createResultHolder() { + return getAggregationFunction().createAggregationResultHolder(); + } + + @Override + protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() { + Random valueRandom = new Random(420); + int numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL; + RoaringBitmap nullBitmap = SyntheticNullBitmapFactories.Periodic.randomInPeriod(numDocs, _nullPeriod); + BlockValSet block = SyntheticBlockValSets.Double.create(numDocs, _nullHandlingEnabled ? nullBitmap : null, + valueRandom::nextDouble); + return Map.of(EXPR, block); + } + + @Override + protected Object createExpectedResult(Map<ExpressionContext, BlockValSet> map) { + double sum = 0.0; + long count = 0; + + BlockValSet blockValSet = getBlockValSetMap().get(EXPR); + double[] doubleValuesSV = blockValSet.getDoubleValuesSV(); + RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); + + for (int i = 0; i < doubleValuesSV.length; i++) { + if (nullBitmap != null && nullBitmap.contains(i)) { + continue; + } + sum += doubleValuesSV[i]; + count++; + } + + if (count != 0) { + return sum / count; + } else { + if (_nullHandlingEnabled) { + return null; + } else { + return Double.NEGATIVE_INFINITY; + } + } + } + + @Override + protected void resetResultHolder(AggregationResultHolder resultHolder) { + resultHolder.setValue(null); + } +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkDistinctCountAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkDistinctCountAggregation.java new file mode 100644 index 00000000000..dd2da3865d6 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkDistinctCountAggregation.java @@ -0,0 +1,107 @@
+/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.perf.aggregation; + +import java.util.Collections; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.DistinctCountAggregationFunction; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; 
+import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.roaringbitmap.RoaringBitmap; + + +@Fork(1) +@BenchmarkMode(Mode.Throughput) +@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class BenchmarkDistinctCountAggregation extends AbstractAggregationFunctionBenchmark.Stable { + private static final ExpressionContext EXPR = ExpressionContext.forIdentifier("col"); + + @Param({"false", "true"}) + private boolean _nullHandlingEnabled; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder().include(BenchmarkDistinctCountAggregation.class.getSimpleName()).build(); + new Runner(opt).run(); + } + + @Override + protected AggregationFunction<?, ?> createAggregationFunction() { + return new DistinctCountAggregationFunction(Collections.singletonList(EXPR), _nullHandlingEnabled); + } + + @Override + protected AggregationResultHolder createResultHolder() { + return getAggregationFunction().createAggregationResultHolder(); + } + + @Override + protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() { + Random valueRandom = new Random(420); + int numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL; + RoaringBitmap nullBitmap = SyntheticNullBitmapFactories.Periodic.randomInPeriod(numDocs, _nullPeriod); + BlockValSet block = SyntheticBlockValSets.Long.create(numDocs, _nullHandlingEnabled ?
nullBitmap : null, + valueRandom::nextLong); + return Map.of(EXPR, block); + } + + @Override + protected Object createExpectedResult(Map<ExpressionContext, BlockValSet> map) { + BlockValSet blockValSet = getBlockValSetMap().get(EXPR); + long[] longValuesSV = blockValSet.getLongValuesSV(); + RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); + + return (int) IntStream.range(0, longValuesSV.length) + .filter(i -> nullBitmap == null || !nullBitmap.contains(i)) + .mapToLong(i -> longValuesSV[i]) + .distinct() + .count(); + } + + @Override + protected void resetResultHolder(AggregationResultHolder resultHolder) { + resultHolder.setValue(null); + } +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkMinAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkMinAggregation.java new file mode 100644 index 00000000000..9bf6eb07e94 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkMinAggregation.java @@ -0,0 +1,123 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.pinot.perf.aggregation; + +import java.util.Collections; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.MinAggregationFunction; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.roaringbitmap.RoaringBitmap; + + +@Fork(1) +@BenchmarkMode(Mode.Throughput) +@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class BenchmarkMinAggregation extends AbstractAggregationFunctionBenchmark.Stable { + private static final ExpressionContext EXPR = ExpressionContext.forIdentifier("col"); + + @Param({"false", "true"}) + private boolean _nullHandlingEnabled; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + 
Options opt = new OptionsBuilder().include(BenchmarkMinAggregation.class.getSimpleName()).build(); + new Runner(opt).run(); + } + + @Override + protected AggregationFunction<?, ?> createAggregationFunction() { + return new MinAggregationFunction(Collections.singletonList(EXPR), _nullHandlingEnabled); + } + + @Override + protected AggregationResultHolder createResultHolder() { + return getAggregationFunction().createAggregationResultHolder(); + } + + @Override + protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() { + Random valueRandom = new Random(420); + int numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL; + RoaringBitmap nullBitmap = SyntheticNullBitmapFactories.Periodic.randomInPeriod(numDocs, _nullPeriod); + BlockValSet block = SyntheticBlockValSets.Double.create(numDocs, _nullHandlingEnabled ? nullBitmap : null, + valueRandom::nextInt); + return Map.of(EXPR, block); + } + + @Override + protected Object createExpectedResult(Map<ExpressionContext, BlockValSet> map) { + Double min = null; + BlockValSet blockValSet = getBlockValSetMap().get(EXPR); + double[] doubleValuesSV = blockValSet.getDoubleValuesSV(); + RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); + + for (int i = 0; i < doubleValuesSV.length; i++) { + if (nullBitmap != null && nullBitmap.contains(i)) { + continue; + } + min = (min == null) ?
doubleValuesSV[i] : Math.min(min, doubleValuesSV[i]); + } + + return min; + } + + @Override + protected void resetResultHolder(AggregationResultHolder resultHolder) { + if (_nullHandlingEnabled) { + resultHolder.setValue(null); + } else { + resultHolder.setValue(Double.POSITIVE_INFINITY); + } + } + + @Override + protected Comparable extractFinalResult(AggregationResultHolder resultHolder) { + if (_nullHandlingEnabled) { + return resultHolder.getResult(); + } else { + return resultHolder.getDoubleResult(); + } + } +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkModeAggregation.java similarity index 97% rename from pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java rename to pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkModeAggregation.java index dcfff95c417..5a900ee3885 100644 --- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkModeAggregation.java +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkModeAggregation.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.pinot.perf; +package org.apache.pinot.perf.aggregation; import java.util.Collections; import java.util.Comparator; @@ -31,6 +31,8 @@ import org.apache.pinot.core.query.aggregation.AggregationResultHolder; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.function.ModeAggregationFunction; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Level; diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumAggregation.java new file mode 100644 index 00000000000..00b9381c754 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumAggregation.java @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.perf.aggregation; + +import java.util.Collections; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.roaringbitmap.RoaringBitmap; + + +@Fork(1) +@BenchmarkMode(Mode.Throughput) +@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class BenchmarkSumAggregation extends AbstractAggregationFunctionBenchmark.Stable { + private static final ExpressionContext EXPR = ExpressionContext.forIdentifier("col"); + + @Param({"false", "true"}) + private boolean _nullHandlingEnabled; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + 
Options opt = new OptionsBuilder().include(BenchmarkSumAggregation.class.getSimpleName()).build(); + new Runner(opt).run(); + } + + @Override + protected AggregationFunction<?, ?> createAggregationFunction() { + return new SumAggregationFunction(Collections.singletonList(EXPR), _nullHandlingEnabled); + } + + @Override + protected AggregationResultHolder createResultHolder() { + return getAggregationFunction().createAggregationResultHolder(); + } + + @Override + protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() { + Random valueRandom = new Random(420); + int numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL; + RoaringBitmap nullBitmap = SyntheticNullBitmapFactories.Periodic.randomInPeriod(numDocs, _nullPeriod); + BlockValSet block = SyntheticBlockValSets.Long.create(numDocs, _nullHandlingEnabled ? nullBitmap : null, + valueRandom::nextInt); + return Map.of(EXPR, block); + } + + @Override + protected Object createExpectedResult(Map<ExpressionContext, BlockValSet> map) { + Double sum = null; + BlockValSet blockValSet = getBlockValSetMap().get(EXPR); + long[] longValuesSV = blockValSet.getLongValuesSV(); + RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); + + for (int i = 0; i < longValuesSV.length; i++) { + if (nullBitmap != null && nullBitmap.contains(i)) { + continue; + } + sum = (sum == null) ?
longValuesSV[i] : sum + longValuesSV[i]; + } + + if (!_nullHandlingEnabled && sum == null) { + return 0.0d; + } else { + return sum; + } + } + + @Override + protected void resetResultHolder(AggregationResultHolder resultHolder) { + if (_nullHandlingEnabled) { + resultHolder.setValue(null); + } else { + resultHolder.setValue(0.0); + } + } + + @Override + protected Comparable extractFinalResult(AggregationResultHolder resultHolder) { + if (_nullHandlingEnabled) { + return resultHolder.getResult(); + } else { + return resultHolder.getDoubleResult(); + } + } +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumQuery.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumQuery.java new file mode 100644 index 00000000000..eb3e4fb7872 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkSumQuery.java @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.perf.aggregation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + + +@Fork(1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 10, time = 1) +@State(Scope.Benchmark) +public class BenchmarkSumQuery extends AbstractAggregationQueryBenchmark { + + @Param({"false", "true"}) + public boolean _nullHandling; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(BenchmarkSumQuery.class.getSimpleName()) + .build(); + + new Runner(opt).run(); + } + + @Override + protected Schema createSchema() { + return new Schema.SchemaBuilder() + .setSchemaName("benchmark") + .addMetricField("col", 
FieldSpec.DataType.INT) + .build(); + } + + @Override + protected TableConfig createTableConfig() { + return new TableConfigBuilder(TableType.OFFLINE) + .setTableName("benchmark") + .setNullHandlingEnabled(true) + .build(); + } + + @Override + protected List<List<Object[][]>> createSegmentsPerServer() { + Random valueRandom = new Random(420); + List<List<Object[][]>> segmentsPerServer = new ArrayList<>(); + segmentsPerServer.add(new ArrayList<>()); + segmentsPerServer.add(new ArrayList<>()); + + // 2 servers + for (int server = 0; server < 2; server++) { + List<Object[][]> segments = segmentsPerServer.get(server); + // 3 segments per server + for (int seg = 0; seg < 3; seg++) { + // 10000 single column rows per segment + Object[][] segment = new Object[10000][1]; + for (int row = 0; row < 10000; row++) { + segment[row][0] = (row % _nullPeriod) == 0 ? null : valueRandom.nextInt(); + } + segments.add(segment); + } + } + + return segmentsPerServer; + } + + @Setup(Level.Trial) + public void setup() throws IOException { + init(_nullHandling); + } + + @Benchmark + public void test(Blackhole bh) { + executeQuery("SELECT SUM(col) FROM mytable", bh); + } +} diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkVarianceAggregation.java b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkVarianceAggregation.java new file mode 100644 index 00000000000..744ef1afd66 --- /dev/null +++ b/pinot-perf/src/main/java/org/apache/pinot/perf/aggregation/BenchmarkVarianceAggregation.java @@ -0,0 +1,117 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.perf.aggregation; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.plan.DocIdSetPlanNode; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.function.AggregationFunction; +import org.apache.pinot.core.query.aggregation.function.VarianceAggregationFunction; +import org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils; +import org.apache.pinot.perf.SyntheticBlockValSets; +import org.apache.pinot.perf.SyntheticNullBitmapFactories; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import 
org.roaringbitmap.RoaringBitmap; + + +@Fork(1) +@BenchmarkMode(Mode.Throughput) +@Warmup(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class BenchmarkVarianceAggregation extends AbstractAggregationFunctionBenchmark.Stable { + private static final ExpressionContext EXPR = ExpressionContext.forIdentifier("col"); + + @Param({"false", "true"}) + private boolean _nullHandlingEnabled; + + @Param({"1", "2", "4", "8", "16", "32", "64", "128"}) + protected int _nullPeriod; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder().include(BenchmarkVarianceAggregation.class.getSimpleName()).build(); + new Runner(opt).run(); + } + + @Override + protected AggregationFunction<?, ?> createAggregationFunction() { + return new VarianceAggregationFunction(Collections.singletonList(EXPR), true, false, _nullHandlingEnabled); + } + + @Override + protected AggregationResultHolder createResultHolder() { + return getAggregationFunction().createAggregationResultHolder(); + } + + @Override + protected Map<ExpressionContext, BlockValSet> createBlockValSetMap() { + Random valueRandom = new Random(420); + int numDocs = DocIdSetPlanNode.MAX_DOC_PER_CALL; + RoaringBitmap nullBitmap = SyntheticNullBitmapFactories.Periodic.randomInPeriod(numDocs, _nullPeriod); + BlockValSet block = SyntheticBlockValSets.Double.create(numDocs, _nullHandlingEnabled ?
nullBitmap : null, + valueRandom::nextInt); + return Map.of(EXPR, block); + } + + @Override + protected Object createExpectedResult(Map<ExpressionContext, BlockValSet> map) { + BlockValSet blockValSet = getBlockValSetMap().get(EXPR); + double[] doubleValuesSV = blockValSet.getDoubleValuesSV(); + RoaringBitmap nullBitmap = blockValSet.getNullBitmap(); + + List<Double> values = IntStream.range(0, doubleValuesSV.length) + .filter(i -> nullBitmap == null || !nullBitmap.contains(i)) + .mapToDouble(i -> doubleValuesSV[i]) + .boxed() + .collect(Collectors.toList()); + + Double variance = null; + if (!values.isEmpty()) { + variance = StatisticalAggregationFunctionUtils.calculateVariance(values, AggregationFunctionType.VARSAMP); + } + return variance; + } + + @Override + protected void resetResultHolder(AggregationResultHolder resultHolder) { + resultHolder.setValue(null); + } +} diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/AvgPair.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/AvgPair.java index eddb0162d8b..eaf5d5d744b 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/AvgPair.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/AvgPair.java @@ -26,6 +26,10 @@ public class AvgPair implements Comparable<AvgPair> { private double _sum; private long _count; + public AvgPair() { + this(0.0, 0L); + } + public AvgPair(double sum, long count) { _sum = sum; _count = count; diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/VarianceTuple.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/VarianceTuple.java index d1159002a18..49b2fc0b225 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/VarianceTuple.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/customobject/VarianceTuple.java @@ -32,6 +32,15 @@ public VarianceTuple(long count,
double sum, double m2) { _m2 = m2; } + public void apply(double value) { + _count++; + _sum += value; + if (_count > 1) { + double t = _count * value - _sum; + _m2 += (t * t) / (_count * (_count - 1)); + } + } + public void apply(long count, double sum, double m2) { if (count == 0) { return;