Skip to content

Commit

Permalink
[null-aggr] Add null handling support in mode aggregation (#12227)
Browse files Browse the repository at this point in the history
Support null handling in the `mode` aggregation function. When null handling is enabled, null values are ignored when the mode is calculated.
  • Loading branch information
gortiz authored Mar 4, 2024
1 parent 9a8fa79 commit 01b0f8c
Show file tree
Hide file tree
Showing 11 changed files with 1,675 additions and 383 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private AggregationFunctionFactory() {

/**
* Given the function information, returns a new instance of the corresponding aggregation function.
* <p>NOTE: Underscores in the function name are ignored.
* <p>NOTE: Underscores in the function name are ignored in V1.
*/
public static AggregationFunction getAggregationFunction(FunctionContext function, boolean nullHandlingEnabled) {
try {
Expand Down Expand Up @@ -208,7 +208,7 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
case AVG:
return new AvgAggregationFunction(arguments, nullHandlingEnabled);
case MODE:
return new ModeAggregationFunction(arguments);
return new ModeAggregationFunction(arguments, nullHandlingEnabled);
case FIRSTWITHTIME: {
Preconditions.checkArgument(numArguments == 3,
"FIRST_WITH_TIME expects 3 arguments, got: %s. The function can be used as "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;


/**
Expand Down Expand Up @@ -67,4 +70,22 @@ protected static ExpressionContext verifySingleArgument(List<ExpressionContext>
arguments.size());
return arguments.get(0);
}

/**
 * Returns the result currently held by the given aggregation result holder, creating and storing one first if the
 * holder has none.
 *
 * @param aggregationResultHolder holder to read the result from and, if needed, to store the new value into
 * @param orCreate supplier invoked only when the holder's current result is {@code null}
 * @return the existing result, or the value obtained from {@code orCreate} after it has been set on the holder
 */
protected static <E> E getValue(AggregationResultHolder aggregationResultHolder, Supplier<E> orCreate) {
  E existing = aggregationResultHolder.getResult();
  if (existing != null) {
    return existing;
  }
  // Nothing stored yet: build the value and register it so later calls reuse it.
  E created = orCreate.get();
  aggregationResultHolder.setValue(created);
  return created;
}

/**
 * Returns the result held for the given group key, creating and storing one first if the holder has none for
 * that key.
 *
 * @param groupByResultHolder holder to read the per-group result from and, if needed, to store the new value into
 * @param groupKey key of the group whose result is requested
 * @param orCreate supplier invoked only when the holder's current result for {@code groupKey} is {@code null}
 * @return the existing result, or the value obtained from {@code orCreate} after it has been set for the key
 */
protected static <E> E getValue(GroupByResultHolder groupByResultHolder, int groupKey, Supplier<E> orCreate) {
  E existing = groupByResultHolder.getResult(groupKey);
  if (existing != null) {
    return existing;
  }
  // Nothing stored for this group yet: build the value and register it under the key.
  E created = orCreate.get();
  groupByResultHolder.setValueForKey(groupKey, created);
  return created;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@
* </ul>
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public class ModeAggregationFunction extends BaseSingleInputAggregationFunction<Map<? extends Number, Long>, Double> {
public class ModeAggregationFunction
extends NullableSingleInputAggregationFunction<Map<? extends Number, Long>, Double> {

private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;

private final MultiModeReducerType _multiModeReducerType;

public ModeAggregationFunction(List<ExpressionContext> arguments) {
super(arguments.get(0));
public ModeAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
super(arguments.get(0), nullHandlingEnabled);

int numArguments = arguments.size();
Preconditions.checkArgument(numArguments <= 2, "Mode expects at most 2 arguments, got: %s", numArguments);
Expand Down Expand Up @@ -263,11 +264,14 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
// For dictionary-encoded expression, store dictionary ids into the dictId map
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();

Int2IntOpenHashMap dictIdValueMap = getDictIdCountMap(aggregationResultHolder, dictionary);
for (int i = 0; i < length; i++) {
dictIdValueMap.merge(dictIds[i], 1, Integer::sum);
}
int[] dictIds = blockValSet.getDictionaryIdsSV();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
dictIdValueMap.merge(dictIds[i], 1, Integer::sum);
}
});
return;
}

Expand All @@ -278,30 +282,38 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
case INT:
Int2LongOpenHashMap intMap = (Int2LongOpenHashMap) valueMap;
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
intMap.merge(intValues[i], 1, Long::sum);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
intMap.merge(intValues[i], 1, Long::sum);
}
});
break;
case LONG:
Long2LongOpenHashMap longMap = (Long2LongOpenHashMap) valueMap;
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
longMap.merge(longValues[i], 1, Long::sum);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
longMap.merge(longValues[i], 1, Long::sum);
}
});
break;
case FLOAT:
Float2LongOpenHashMap floatMap = (Float2LongOpenHashMap) valueMap;
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
floatMap.merge(floatValues[i], 1, Long::sum);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
floatMap.merge(floatValues[i], 1, Long::sum);
}
});
break;
case DOUBLE:
Double2LongOpenHashMap doubleMap = (Double2LongOpenHashMap) valueMap;
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
doubleMap.merge(doubleValues[i], 1, Long::sum);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
doubleMap.merge(doubleValues[i], 1, Long::sum);
}
});
break;
default:
throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
Expand All @@ -317,9 +329,12 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
for (int i = 0; i < length; i++) {
getDictIdCountMap(groupByResultHolder, groupKeyArray[i], dictionary).merge(dictIds[i], 1, Integer::sum);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
Int2IntOpenHashMap dictIdCountMap = getDictIdCountMap(groupByResultHolder, groupKeyArray[i], dictionary);
dictIdCountMap.merge(dictIds[i], 1, Integer::sum);
}
});
return;
}

Expand All @@ -328,27 +343,35 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], intValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], intValues[i]);
}
});
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], longValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], longValues[i]);
}
});
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], floatValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], floatValues[i]);
}
});
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], doubleValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], doubleValues[i]);
}
});
break;
default:
throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
Expand All @@ -364,11 +387,13 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
Dictionary dictionary = blockValSet.getDictionary();
if (dictionary != null) {
int[] dictIds = blockValSet.getDictionaryIdsSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
getDictIdCountMap(groupByResultHolder, groupKey, dictionary).merge(dictIds[i], 1, Integer::sum);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
for (int groupKey : groupKeysArray[i]) {
getDictIdCountMap(groupByResultHolder, groupKey, dictionary).merge(dictIds[i], 1, Integer::sum);
}
}
}
});
return;
}

Expand All @@ -377,35 +402,43 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
switch (storedType) {
case INT:
int[] intValues = blockValSet.getIntValuesSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, intValues[i]);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, intValues[i]);
}
}
}
});
break;
case LONG:
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, longValues[i]);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, longValues[i]);
}
}
}
});
break;
case FLOAT:
float[] floatValues = blockValSet.getFloatValuesSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, floatValues[i]);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, floatValues[i]);
}
}
}
});
break;
case DOUBLE:
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, doubleValues[i]);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
for (int groupKey : groupKeysArray[i]) {
setValueForGroupKeys(groupByResultHolder, groupKey, doubleValues[i]);
}
}
}
});
break;
default:
throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType);
Expand Down Expand Up @@ -467,7 +500,11 @@ public ColumnDataType getFinalResultColumnType() {
@Override
public Double extractFinalResult(Map<? extends Number, Long> intermediateResult) {
if (intermediateResult.isEmpty()) {
return DEFAULT_FINAL_RESULT;
if (_nullHandlingEnabled) {
return null;
} else {
return DEFAULT_FINAL_RESULT;
}
} else if (intermediateResult instanceof Int2LongOpenHashMap) {
return extractFinalResult((Int2LongOpenHashMap) intermediateResult);
} else if (intermediateResult instanceof Long2LongOpenHashMap) {
Expand Down
Loading

0 comments on commit 01b0f8c

Please sign in to comment.