Skip to content

Commit

Permalink
Improve null handling performance for nullable single input aggregation functions (apache#13791)
Browse files Browse the repository at this point in the history

Modify aggregation functions that were not extending NullableSingleInputAggregationFunction to do so and optimize their code to use the methods included there.
  • Loading branch information
yashmayya authored Aug 30, 2024
1 parent fc132c3 commit cf52567
Show file tree
Hide file tree
Showing 56 changed files with 3,727 additions and 1,549 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,9 @@ public boolean equals(Object o) {
public int hashCode() {
return Arrays.hashCode(_values);
}

/**
 * Returns the text that {@link Arrays#toString} produces for the {@code _values} field.
 */
@Override
public String toString() {
return Arrays.toString(_values);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,9 @@ public boolean equals(Object o) {
public int hashCode() {
return Arrays.hashCode(_values);
}

/**
 * Returns the text that {@link Arrays#toString} produces for the {@code _values} field.
 */
@Override
public String toString() {
return Arrays.toString(_values);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,17 @@
import org.apache.pinot.segment.local.customobject.AvgPair;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.roaringbitmap.RoaringBitmap;


public class AvgAggregationFunction extends BaseSingleInputAggregationFunction<AvgPair, Double> {
public class AvgAggregationFunction extends NullableSingleInputAggregationFunction<AvgPair, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
private final boolean _nullHandlingEnabled;

/**
 * Builds the function from an argument list.
 *
 * <p>The list is handed to {@code verifySingleArgument} together with the function name {@code "AVG"}; the value
 * it returns is forwarded, along with the flag, to the constructor that takes an {@link ExpressionContext}.
 *
 * @param arguments argument expressions passed to {@code verifySingleArgument}
 * @param nullHandlingEnabled flag forwarded unchanged to the delegated constructor
 */
public AvgAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
this(verifySingleArgument(arguments, "AVG"), nullHandlingEnabled);
}

protected AvgAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) {
super(expression);
_nullHandlingEnabled = nullHandlingEnabled;
super(expression, nullHandlingEnabled);
}

@Override
Expand All @@ -66,73 +63,37 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (_nullHandlingEnabled) {
RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
if (nullBitmap != null && !nullBitmap.isEmpty()) {
aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap);
return;
}
}

if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
double sum = 0.0;
for (int i = 0; i < length; i++) {
sum += doubleValues[i];
}
setAggregationResult(aggregationResultHolder, sum, length);
} else {
// Serialized AvgPair
byte[][] bytesValues = blockValSet.getBytesValuesSV();
double sum = 0.0;
long count = 0L;
for (int i = 0; i < length; i++) {
AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
sum += value.getSum();
count += value.getCount();
}
setAggregationResult(aggregationResultHolder, sum, count);
}
}

private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder,
BlockValSet blockValSet, RoaringBitmap nullBitmap) {
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
if (nullBitmap.getCardinality() < length) {
double sum = 0.0;
long count = 0L;
// TODO: need to update the for-loop terminating condition to: i < length & i < doubleValues.length?
for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
sum += doubleValues[i];
count++;
}
AvgPair avgPair = new AvgPair();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
avgPair.apply(doubleValues[i], 1);
}
setAggregationResult(aggregationResultHolder, sum, count);
});
// Only set the aggregation result when there is at least one non-null input value
if (avgPair.getCount() != 0) {
updateAggregationResult(aggregationResultHolder, avgPair.getSum(), avgPair.getCount());
}
// Note: when all input values are null (nullBitmap.getCardinality() == values.length), avg is null. As a result,
// we don't call setAggregationResult.
} else {
// Serialized AvgPair
byte[][] bytesValues = blockValSet.getBytesValuesSV();
if (nullBitmap.getCardinality() < length) {
double sum = 0.0;
long count = 0L;
// TODO: need to update the for-loop terminating condition to: i < length & i < bytesValues.length?
for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
sum += value.getSum();
count += value.getCount();
}
AvgPair avgPair = new AvgPair();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
AvgPair value = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
avgPair.apply(value);
}
setAggregationResult(aggregationResultHolder, sum, count);
});
// Only set the aggregation result when there is at least one non-null input value
if (avgPair.getCount() != 0) {
updateAggregationResult(aggregationResultHolder, avgPair.getSum(), avgPair.getCount());
}
}
}

protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, double sum, long count) {
protected void updateAggregationResult(AggregationResultHolder aggregationResultHolder, double sum, long count) {
AvgPair avgPair = aggregationResultHolder.getResult();
if (avgPair == null) {
aggregationResultHolder.setValue(new AvgPair(sum, count));
Expand All @@ -145,55 +106,23 @@ protected void setAggregationResult(AggregationResultHolder aggregationResultHol
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (_nullHandlingEnabled) {
RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
if (nullBitmap != null && !nullBitmap.isEmpty()) {
aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray, groupByResultHolder, blockValSet, nullBitmap);
return;
}
}

if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
setGroupByResult(groupKeyArray[i], groupByResultHolder, doubleValues[i], 1L);
}
} else {
// Serialized AvgPair
byte[][] bytesValues = blockValSet.getBytesValuesSV();
for (int i = 0; i < length; i++) {
AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
setGroupByResult(groupKeyArray[i], groupByResultHolder, avgPair.getSum(), avgPair.getCount());
}
}
}

private void aggregateGroupBySVNullHandlingEnabled(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
// TODO: need to update the for-loop terminating condition to: i < length & i < valueArray.length?
if (nullBitmap.getCardinality() < length) {
for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
int groupKey = groupKeyArray[i];
setGroupByResult(groupKey, groupByResultHolder, doubleValues[i], 1L);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
updateGroupByResult(groupKeyArray[i], groupByResultHolder, doubleValues[i], 1L);
}
}
});
} else {
// Serialized AvgPair
byte[][] bytesValues = blockValSet.getBytesValuesSV();
// TODO: need to update the for-loop terminating condition to: i < length & i < valueArray.length?
if (nullBitmap.getCardinality() < length) {
for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
int groupKey = groupKeyArray[i];
AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
setGroupByResult(groupKey, groupByResultHolder, avgPair.getSum(), avgPair.getCount());
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
AvgPair avgPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(bytesValues[i]);
updateGroupByResult(groupKeyArray[i], groupByResultHolder, avgPair.getSum(), avgPair.getCount());
}
}
});
}
}

Expand All @@ -207,7 +136,7 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
for (int i = 0; i < length; i++) {
double value = doubleValues[i];
for (int groupKey : groupKeysArray[i]) {
setGroupByResult(groupKey, groupByResultHolder, value, 1L);
updateGroupByResult(groupKey, groupByResultHolder, value, 1L);
}
}
} else {
Expand All @@ -218,13 +147,13 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
double sum = avgPair.getSum();
long count = avgPair.getCount();
for (int groupKey : groupKeysArray[i]) {
setGroupByResult(groupKey, groupByResultHolder, sum, count);
updateGroupByResult(groupKey, groupByResultHolder, sum, count);
}
}
}
}

protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sum, long count) {
protected void updateGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sum, long count) {
AvgPair avgPair = groupByResultHolder.getResult(groupKey);
if (avgPair == null) {
groupByResultHolder.setValueForKey(groupKey, new AvgPair(sum, count));
Expand All @@ -237,7 +166,7 @@ protected void setGroupByResult(int groupKey, GroupByResultHolder groupByResultH
public AvgPair extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
AvgPair avgPair = aggregationResultHolder.getResult();
if (avgPair == null) {
return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L);
return _nullHandlingEnabled ? null : new AvgPair();
}
return avgPair;
}
Expand All @@ -246,7 +175,7 @@ public AvgPair extractAggregationResult(AggregationResultHolder aggregationResul
public AvgPair extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
AvgPair avgPair = groupByResultHolder.getResult(groupKey);
if (avgPair == null) {
return _nullHandlingEnabled ? null : new AvgPair(0.0, 0L);
return _nullHandlingEnabled ? null : new AvgPair();
}
return avgPair;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
}
count += values.length;
}
setAggregationResult(aggregationResultHolder, sum, count);
updateAggregationResult(aggregationResultHolder, sum, count);
}

@Override
Expand Down Expand Up @@ -81,6 +81,6 @@ private void aggregateOnGroupKey(int groupKey, GroupByResultHolder groupByResult
sum += value;
}
long count = values.length;
setGroupByResult(groupKey, groupByResultHolder, sum, count);
updateGroupByResult(groupKey, groupByResultHolder, sum, count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@
import org.apache.pinot.core.query.aggregation.groupby.IntGroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.RoaringBitmap;


// TODO: change this to implement BaseSingleInputAggregationFunction<Boolean, Boolean> when we get proper
// handling of booleans in serialization - today this would fail because ColumnDataType#convert assumes
// that the boolean is encoded as its stored type (an integer)
public abstract class BaseBooleanAggregationFunction extends BaseSingleInputAggregationFunction<Integer, Integer> {
public abstract class BaseBooleanAggregationFunction extends NullableSingleInputAggregationFunction<Integer, Integer> {

private final BooleanMerge _merger;
private final boolean _nullHandlingEnabled;

protected enum BooleanMerge {
AND {
Expand Down Expand Up @@ -84,8 +82,7 @@ int getDefaultValue() {

protected BaseBooleanAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled,
BooleanMerge merger) {
super(expression);
_nullHandlingEnabled = nullHandlingEnabled;
super(expression, nullHandlingEnabled);
_merger = merger;
}

Expand Down Expand Up @@ -115,27 +112,22 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde

int[] bools = blockValSet.getIntValuesSV();
if (_nullHandlingEnabled) {
int agg = getInt(aggregationResultHolder.getResult());

// early terminate on a per-block level to allow the
// loop below to be more tightly optimized (avoid a branch)
if (_merger.isTerminal(agg)) {
if (_merger.isTerminal(getInt(aggregationResultHolder.getResult()))) {
return;
}

RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
if (nullBitmap == null) {
nullBitmap = new RoaringBitmap();
} else if (nullBitmap.getCardinality() > length) {
return;
}

for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
agg = _merger.merge(agg, bools[i]);
aggregationResultHolder.setValue((Object) agg);
Integer aggResult = foldNotNull(length, blockValSet, aggregationResultHolder.getResult(),
(acum, from, to) -> {
int innerBool = acum == null ? _merger.getDefaultValue() : acum;
for (int i = from; i < to; i++) {
innerBool = _merger.merge(innerBool, bools[i]);
}
}
return innerBool;
});

aggregationResultHolder.setValue(aggResult);
} else {
int agg = aggregationResultHolder.getIntResult();

Expand Down Expand Up @@ -164,20 +156,13 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol

int[] bools = blockValSet.getIntValuesSV();
if (_nullHandlingEnabled) {
RoaringBitmap nullBitmap = blockValSet.getNullBitmap();
if (nullBitmap == null) {
nullBitmap = new RoaringBitmap();
} else if (nullBitmap.getCardinality() > length) {
return;
}

for (int i = 0; i < length; i++) {
if (!nullBitmap.contains(i)) {
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
int groupByKey = groupKeyArray[i];
int agg = getInt(groupByResultHolder.getResult(groupByKey));
groupByResultHolder.setValueForKey(groupByKey, (Object) _merger.merge(agg, bools[i]));
}
}
});
} else {
for (int i = 0; i < length; i++) {
int groupByKey = groupKeyArray[i];
Expand Down
Loading

0 comments on commit cf52567

Please sign in to comment.