Skip to content

Commit

Permalink
Add support for defining custom window frame bounds for window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
yashmayya committed Oct 18, 2024
1 parent 9046af2 commit f1425c0
Show file tree
Hide file tree
Showing 22 changed files with 2,073 additions and 300 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,22 @@ public void testFilteredAggregationWithNoValueMatchingAggregationFilterWithOptio
assertEquals(result.get("numRowsResultSet").asInt(), 0);
}

/**
 * Runs two window function queries that declare an explicit window frame and checks that neither response is
 * flagged by {@code assertNoError}: one with a RANGE frame (UNBOUNDED PRECEDING to CURRENT ROW) and one with a
 * ROWS frame using offset bounds (2 PRECEDING to 2 FOLLOWING).
 */
@Test
public void testWindowFunction()
    throws Exception {
  // RANGE frame: running maximum per airline, ordered by day
  String rangeFrameQuery =
      "SELECT AirlineID, ArrDelay, DaysSinceEpoch, MAX(ArrDelay) OVER(PARTITION BY AirlineID ORDER BY DaysSinceEpoch "
          + "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS MaxAirlineDelaySoFar FROM mytable;";
  assertNoError(postQuery(rangeFrameQuery));

  // ROWS frame with offsets: sum over the two rows before and the two rows after the current row
  String rowsFrameQuery =
      "SELECT AirlineID, ArrDelay, DaysSinceEpoch, SUM(ArrDelay) OVER(PARTITION BY AirlineID ORDER BY DaysSinceEpoch "
          + "ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS SumAirlineDelayInWindow FROM mytable;";
  assertNoError(postQuery(rowsFrameQuery));
}

@Override
protected String getTableName() {
return _tableName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
Expand All @@ -42,6 +44,8 @@
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
Expand All @@ -65,11 +69,14 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {

// Kinds of window functions that this rule accepts; any other kind is rejected during validation.
// The OTHER_FUNCTION kinds that are supported are: BOOL_AND, BOOL_OR
private static final Set<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND =
Set.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK,
private static final EnumSet<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND =
EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK,
SqlKind.DENSE_RANK, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
SqlKind.OTHER_FUNCTION);

// Ranking window function kinds: ROW_NUMBER, RANK and DENSE_RANK
private static final EnumSet<SqlKind> RANK_BASED_FUNCTION_KIND =
EnumSet.of(SqlKind.ROW_NUMBER, SqlKind.RANK, SqlKind.DENSE_RANK);

/**
 * Creates the rule with an operand that matches a {@link Window} rel node (with the {@code any()} child policy).
 *
 * @param factory builder factory handed to the {@link RelOptRule} super constructor
 */
public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
// The description argument is passed as null; presumably the base class then derives a default one - confirm.
super(operand(Window.class, any()), factory, null);
}
Expand Down Expand Up @@ -144,60 +151,88 @@ public void onMatch(RelOptRuleCall call) {
List.of(windowGroup)));
}

/**
* Replaces the reference to literal arguments in the window group with the actual literal values.
* NOTE: {@link Window} has a field called "constants" which contains the literal values. If the input reference is
* beyond the window input size, it is a reference to the constants.
*/
private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
Window.Group oldWindowGroup = window.groups.get(0);
int windowInputSize = window.getInput().getRowType().getFieldCount();
ImmutableList<Window.RexWinAggCall> oldAggCalls = oldWindowGroup.aggCalls;
List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldAggCalls.size());
boolean aggCallChanged = false;
for (Window.RexWinAggCall oldAggCall : oldAggCalls) {
RelNode input = ((HepRelVertex) window.getInput()).getCurrentRel();
int numInputFields = input.getRowType().getFieldCount();
List<RexNode> projects = input instanceof Project ? ((Project) input).getProjects() : null;

List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldWindowGroup.aggCalls.size());
boolean windowChanged = false;
for (Window.RexWinAggCall oldAggCall : oldWindowGroup.aggCalls) {
boolean changed = false;
List<RexNode> oldAggCallArgList = oldAggCall.getOperands();
List<RexNode> rexList = new ArrayList<>(oldAggCallArgList.size());
for (RexNode rexNode : oldAggCallArgList) {
RexNode newRexNode = rexNode;
if (rexNode instanceof RexInputRef) {
RexInputRef inputRef = (RexInputRef) rexNode;
int inputRefIndex = inputRef.getIndex();
// If the input reference is greater than the window input size, it is a reference to the constants
if (inputRefIndex >= windowInputSize) {
newRexNode = window.constants.get(inputRefIndex - windowInputSize);
changed = true;
aggCallChanged = true;
} else {
RelNode windowInputRelNode = ((HepRelVertex) window.getInput()).getCurrentRel();
if (windowInputRelNode instanceof LogicalProject) {
RexNode inputRefRexNode = ((LogicalProject) windowInputRelNode).getProjects().get(inputRefIndex);
if (inputRefRexNode instanceof RexLiteral) {
// If the input reference is a literal, replace it with the literal value
newRexNode = inputRefRexNode;
changed = true;
aggCallChanged = true;
}
}
}
List<RexNode> oldOperands = oldAggCall.getOperands();
List<RexNode> newOperands = new ArrayList<>(oldOperands.size());
for (RexNode oldOperand : oldOperands) {
RexLiteral literal = getLiteral(oldOperand, numInputFields, window.constants, projects);
if (literal != null) {
newOperands.add(literal);
changed = true;
windowChanged = true;
} else {
newOperands.add(oldOperand);
}
rexList.add(newRexNode);
}
if (changed) {
newAggCallWindow.add(
new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, rexList,
new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, newOperands,
oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls));
} else {
newAggCallWindow.add(oldAggCall);
}
}
if (aggCallChanged) {
return new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, oldWindowGroup.lowerBound,
oldWindowGroup.upperBound, oldWindowGroup.orderKeys, newAggCallWindow);

RexWindowBound lowerBound = oldWindowGroup.lowerBound;
RexNode offset = lowerBound.getOffset();
if (offset != null) {
RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal != null) {
lowerBound = lowerBound.isPreceding() ? RexWindowBounds.preceding(literal) : RexWindowBounds.following(literal);
windowChanged = true;
}
}
RexWindowBound upperBound = oldWindowGroup.upperBound;
offset = upperBound.getOffset();
if (offset != null) {
RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal != null) {
upperBound = lowerBound.isFollowing() ? RexWindowBounds.following(literal) : RexWindowBounds.preceding(literal);
windowChanged = true;
}
}
return oldWindowGroup;

return windowChanged ? new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, lowerBound, upperBound,
oldWindowGroup.orderKeys, newAggCallWindow) : oldWindowGroup;
}

/**
 * Resolves the literal that the given node refers to, if any.
 *
 * @param rexNode node to resolve; anything other than a {@link RexInputRef} yields {@code null}
 * @param numInputFields number of input fields; a reference index at or beyond it points into {@code constants}
 * @param constants literals addressed by reference indexes starting at {@code numInputFields}
 * @param projects expressions addressed by reference indexes below {@code numInputFields}, or {@code null} when
 *                 there are none to look up
 * @return the referenced literal, or {@code null} when the node does not resolve to a literal
 */
@Nullable
private RexLiteral getLiteral(RexNode rexNode, int numInputFields, ImmutableList<RexLiteral> constants,
    @Nullable List<RexNode> projects) {
  if (rexNode instanceof RexInputRef) {
    int refIndex = ((RexInputRef) rexNode).getIndex();
    if (refIndex >= numInputFields) {
      // The reference points past the input fields, i.e. into the constants
      return constants.get(refIndex - numInputFields);
    }
    RexNode projected = projects == null ? null : projects.get(refIndex);
    return projected instanceof RexLiteral ? (RexLiteral) projected : null;
  }
  return null;
}

private void validateWindows(Window window) {
int numGroups = window.groups.size();
// For Phase 1 we only handle single window groups
Preconditions.checkState(numGroups <= 1,
Preconditions.checkState(numGroups == 1,
String.format("Currently only 1 window group is supported, query has %d groups", numGroups));

// Validate that only supported window aggregation functions are present
Expand All @@ -209,33 +244,23 @@ private void validateWindows(Window window) {
}

private void validateWindowAggCallsSupported(Window.Group windowGroup) {
for (int i = 0; i < windowGroup.aggCalls.size(); i++) {
Window.RexWinAggCall aggCall = windowGroup.aggCalls.get(i);
for (Window.RexWinAggCall aggCall : windowGroup.aggCalls) {
SqlKind aggKind = aggCall.getKind();
Preconditions.checkState(SUPPORTED_WINDOW_FUNCTION_KIND.contains(aggKind),
String.format("Unsupported Window function kind: %s. Only aggregation functions are supported!", aggKind));
}
}

/**
 * Validates the window frame of the given group. ROWS frames are accepted with any bounds; RANGE frames are
 * rejected when either bound is an offset bound ({@code <offset> PRECEDING} or {@code <offset> FOLLOWING}).
 *
 * @param windowGroup window group whose frame is checked
 * @throws IllegalStateException if the frame is a RANGE frame with an offset bound
 */
private void validateWindowFrames(Window.Group windowGroup) {
  // getOffset() is non-null only for offset bounds (it is null for UNBOUNDED and CURRENT ROW), so this also
  // catches frames such as "2 FOLLOWING AND UNBOUNDED FOLLOWING" or "UNBOUNDED PRECEDING AND 2 PRECEDING"
  boolean hasOffset = windowGroup.lowerBound.getOffset() != null || windowGroup.upperBound.getOffset() != null;

  if (!windowGroup.isRows) {
    Preconditions.checkState(!hasOffset, "RANGE window frame with offset PRECEDING / FOLLOWING is not supported");
  }
}

private boolean isPartitionByOnlyQuery(Window.Group windowGroup) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,6 @@ public Void visitWindow(WindowNode node, Void context) {

ImmutableBitSet keys = ImmutableBitSet.of(node.getKeys());
boolean isRow = node.getWindowFrameType() == WindowNode.WindowFrameType.ROWS;
// As explained in RelToPlanNodeConverter, Pinot only supports UNBOUND_PRECEDING
RexWindowBound lowerBound = RexWindowBounds.UNBOUNDED_PRECEDING;
RexWindowBound upperBound = node.getUpperBound() == Integer.MAX_VALUE ? RexWindowBounds.UNBOUNDED_FOLLOWING
: RexWindowBounds.CURRENT_ROW;
RelCollation orderKeys = RelCollations.of(node.getCollations());

List<Window.RexWinAggCall> aggCalls = new ArrayList<>();
Expand All @@ -332,7 +328,9 @@ public Void visitWindow(WindowNode node, Void context) {
aggCalls.add(winCall);
}

Window.Group group = new Window.Group(keys, isRow, lowerBound, upperBound, orderKeys, aggCalls);
Window.Group group =
new Window.Group(keys, isRow, getWindowBound(node.getLowerBound()), getWindowBound(node.getUpperBound()),
orderKeys, aggCalls);

List<RexLiteral> constants =
node.getConstants().stream().map(constant -> RexExpressionUtils.toRexLiteral(_builder, constant))
Expand All @@ -351,6 +349,20 @@ public Void visitWindow(WindowNode node, Void context) {
return null;
}

/**
 * Converts an integer frame bound into the corresponding {@link RexWindowBound}.
 *
 * @param bound {@code Integer.MIN_VALUE} for UNBOUNDED PRECEDING, {@code Integer.MAX_VALUE} for UNBOUNDED
 *              FOLLOWING, {@code 0} for CURRENT ROW; any other negative value is turned into a PRECEDING bound
 *              and any other positive value into a FOLLOWING bound, both with the absolute value as offset
 * @return the window bound matching the given value
 */
private RexWindowBound getWindowBound(int bound) {
  switch (bound) {
    case Integer.MIN_VALUE:
      return RexWindowBounds.UNBOUNDED_PRECEDING;
    case Integer.MAX_VALUE:
      return RexWindowBounds.UNBOUNDED_FOLLOWING;
    case 0:
      return RexWindowBounds.CURRENT_ROW;
    default:
      // Negating is safe here because Integer.MIN_VALUE has already been handled above
      return bound < 0 ? RexWindowBounds.preceding(_builder.literal(-bound))
          : RexWindowBounds.following(_builder.literal(bound));
  }
}

@Override
public Void visitSetOp(SetOpNode node, Void context) {
List<RelNode> inputs = inputsAsList(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,35 @@ private WindowNode convertLogicalWindow(LogicalWindow node) {
WindowNode.WindowFrameType windowFrameType =
windowGroup.isRows ? WindowNode.WindowFrameType.ROWS : WindowNode.WindowFrameType.RANGE;

// TODO: For now only the default frame is supported. Add support for custom frames including rows support.
// Frame literals come in the constants from the LogicalWindow and the bound.getOffset() stores the
// InputRef to the constants array offset by the input array length. These need to be extracted here and
// set to the bounds.
// Lower bound can only be unbounded preceding for now, set to Integer.MIN_VALUE
// Change PlanNodeToRelConverted once this limitation is removed here
int lowerBound = Integer.MIN_VALUE;
// Upper bound can only be unbounded following or current row for now
int upperBound = windowGroup.upperBound.isUnbounded() ? Integer.MAX_VALUE : 0;
int lowerBound;
if (windowGroup.lowerBound.isUnbounded()) {
// Lower bound can't be unbounded following
lowerBound = Integer.MIN_VALUE;
} else if (windowGroup.lowerBound.isCurrentRow()) {
lowerBound = 0;
} else {
// The literal value is extracted from the constants in the PinotWindowExchangeNodeInsertRule
RexLiteral offset = (RexLiteral) windowGroup.lowerBound.getOffset();
lowerBound = offset == null ? Integer.MIN_VALUE
: (windowGroup.lowerBound.isPreceding() ? -1 * RexExpressionUtils.getValueAsInt(offset)
: RexExpressionUtils.getValueAsInt(offset));
}
int upperBound;
if (windowGroup.upperBound.isUnbounded()) {
// Upper bound can't be unbounded preceding
upperBound = Integer.MAX_VALUE;
} else if (windowGroup.upperBound.isCurrentRow()) {
upperBound = 0;
} else {
// The literal value is extracted from the constants in the PinotWindowExchangeNodeInsertRule
RexLiteral offset = (RexLiteral) windowGroup.upperBound.getOffset();
upperBound = offset == null ? Integer.MAX_VALUE
: (windowGroup.upperBound.isFollowing() ? RexExpressionUtils.getValueAsInt(offset)
: -1 * RexExpressionUtils.getValueAsInt(offset));
}

// TODO: Constants are used to store constants needed such as the frame literals. For now just save this, need to
// extract the constant values into bounds as a part of frame support.
// TODO: The constants are already extracted in the PinotWindowExchangeNodeInsertRule, we can remove them from
// the WindowNode and plan serde.
List<RexExpression.Literal> constants = new ArrayList<>(node.constants.size());
for (RexLiteral constant : node.constants) {
constants.add(RexExpressionUtils.fromRexLiteral(constant));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class WindowNode extends BasePlanNode {
private final List<RelFieldCollation> _collations;
private final List<RexExpression.FunctionCall> _aggCalls;
private final WindowFrameType _windowFrameType;
// Both these bounds are relative to current row; 0 means current row, -1 means previous row, 1 means next row, etc.
private final int _lowerBound;
private final int _upperBound;
private final List<RexExpression.Literal> _constants;
Expand Down
Loading

0 comments on commit f1425c0

Please sign in to comment.