From f1425c0b0a344aeeaa144cefdb12ee8a55991c8b Mon Sep 17 00:00:00 2001 From: Yash Mayya Date: Thu, 17 Oct 2024 14:55:09 +0530 Subject: [PATCH] Add support for defining custom window frame bounds for window functions --- .../MultiStageEngineIntegrationTest.java | 16 + .../PinotWindowExchangeNodeInsertRule.java | 139 +- .../logical/PlanNodeToRelConverter.java | 22 +- .../logical/RelToPlanNodeConverter.java | 39 +- .../query/planner/plannode/WindowNode.java | 1 + .../pinot/query/QueryCompilationTest.java | 81 + .../queries/WindowFunctionPlans.json | 12 +- .../operator/WindowAggregateOperator.java | 76 +- .../operator/window/WindowFunction.java | 7 +- .../window/WindowFunctionFactory.java | 19 +- .../aggregate/AggregateWindowFunction.java | 243 ++- .../window/aggregate/WindowFrame.java | 75 + .../window/range/DenseRankWindowFunction.java | 7 +- ...tion.java => RankBasedWindowFunction.java} | 15 +- .../window/range/RankWindowFunction.java | 7 +- .../window/range/RowNumberWindowFunction.java | 7 +- .../value/FirstValueWindowFunction.java | 69 +- .../window/value/LagValueWindowFunction.java | 8 +- .../window/value/LastValueWindowFunction.java | 73 +- .../window/value/LeadValueWindowFunction.java | 8 +- .../window/value/ValueWindowFunction.java | 5 +- .../operator/WindowAggregateOperatorTest.java | 1444 ++++++++++++++++- 22 files changed, 2073 insertions(+), 300 deletions(-) create mode 100644 pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowFrame.java rename pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/{RangeWindowFunction.java => RankBasedWindowFunction.java} (78%) diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java index b450612d45e..e91f9cc719f 100644 --- 
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java @@ -1078,6 +1078,22 @@ public void testFilteredAggregationWithNoValueMatchingAggregationFilterWithOptio assertEquals(result.get("numRowsResultSet").asInt(), 0); } + @Test + public void testWindowFunction() + throws Exception { + String query = + "SELECT AirlineID, ArrDelay, DaysSinceEpoch, MAX(ArrDelay) OVER(PARTITION BY AirlineID ORDER BY DaysSinceEpoch " + + "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS MaxAirlineDelaySoFar FROM mytable;"; + JsonNode jsonNode = postQuery(query); + assertNoError(jsonNode); + + query = + "SELECT AirlineID, ArrDelay, DaysSinceEpoch, SUM(ArrDelay) OVER(PARTITION BY AirlineID ORDER BY DaysSinceEpoch " + + "ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS SumAirlineDelayInWindow FROM mytable;"; + jsonNode = postQuery(query); + assertNoError(jsonNode); + } + @Override protected String getTableName() { return _tableName; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java index 2050aa3b8a2..de6fa35f5a5 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java @@ -22,9 +22,11 @@ import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Collections; +import java.util.EnumSet; import java.util.HashSet; import java.util.List; import java.util.Set; +import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -42,6 +44,8 
@@ import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexWindowBound; +import org.apache.calcite.rex.RexWindowBounds; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlTypeName; @@ -65,11 +69,14 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule { // Supported window functions // OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR - private static final Set SUPPORTED_WINDOW_FUNCTION_KIND = - Set.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK, + private static final EnumSet SUPPORTED_WINDOW_FUNCTION_KIND = + EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK, SqlKind.DENSE_RANK, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE, SqlKind.OTHER_FUNCTION); + private static final EnumSet RANK_BASED_FUNCTION_KIND = + EnumSet.of(SqlKind.ROW_NUMBER, SqlKind.RANK, SqlKind.DENSE_RANK); + public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) { super(operand(Window.class, any()), factory, null); } @@ -144,60 +151,88 @@ public void onMatch(RelOptRuleCall call) { List.of(windowGroup))); } + /** + * Replaces the reference to literal arguments in the window group with the actual literal values. + * NOTE: {@link Window} has a field called "constants" which contains the literal values. If the input reference is + * beyond the window input size, it is a reference to the constants. 
+ */ private Window.Group updateLiteralArgumentsInWindowGroup(Window window) { Window.Group oldWindowGroup = window.groups.get(0); - int windowInputSize = window.getInput().getRowType().getFieldCount(); - ImmutableList oldAggCalls = oldWindowGroup.aggCalls; - List newAggCallWindow = new ArrayList<>(oldAggCalls.size()); - boolean aggCallChanged = false; - for (Window.RexWinAggCall oldAggCall : oldAggCalls) { + RelNode input = ((HepRelVertex) window.getInput()).getCurrentRel(); + int numInputFields = input.getRowType().getFieldCount(); + List projects = input instanceof Project ? ((Project) input).getProjects() : null; + + List newAggCallWindow = new ArrayList<>(oldWindowGroup.aggCalls.size()); + boolean windowChanged = false; + for (Window.RexWinAggCall oldAggCall : oldWindowGroup.aggCalls) { boolean changed = false; - List oldAggCallArgList = oldAggCall.getOperands(); - List rexList = new ArrayList<>(oldAggCallArgList.size()); - for (RexNode rexNode : oldAggCallArgList) { - RexNode newRexNode = rexNode; - if (rexNode instanceof RexInputRef) { - RexInputRef inputRef = (RexInputRef) rexNode; - int inputRefIndex = inputRef.getIndex(); - // If the input reference is greater than the window input size, it is a reference to the constants - if (inputRefIndex >= windowInputSize) { - newRexNode = window.constants.get(inputRefIndex - windowInputSize); - changed = true; - aggCallChanged = true; - } else { - RelNode windowInputRelNode = ((HepRelVertex) window.getInput()).getCurrentRel(); - if (windowInputRelNode instanceof LogicalProject) { - RexNode inputRefRexNode = ((LogicalProject) windowInputRelNode).getProjects().get(inputRefIndex); - if (inputRefRexNode instanceof RexLiteral) { - // If the input reference is a literal, replace it with the literal value - newRexNode = inputRefRexNode; - changed = true; - aggCallChanged = true; - } - } - } + List oldOperands = oldAggCall.getOperands(); + List newOperands = new ArrayList<>(oldOperands.size()); + for (RexNode oldOperand : 
oldOperands) { + RexLiteral literal = getLiteral(oldOperand, numInputFields, window.constants, projects); + if (literal != null) { + newOperands.add(literal); + changed = true; + windowChanged = true; + } else { + newOperands.add(oldOperand); } - rexList.add(newRexNode); } if (changed) { newAggCallWindow.add( - new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, rexList, + new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, newOperands, oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls)); } else { newAggCallWindow.add(oldAggCall); } } - if (aggCallChanged) { - return new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, oldWindowGroup.lowerBound, - oldWindowGroup.upperBound, oldWindowGroup.orderKeys, newAggCallWindow); + + RexWindowBound lowerBound = oldWindowGroup.lowerBound; + RexNode offset = lowerBound.getOffset(); + if (offset != null) { + RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects); + if (literal != null) { + lowerBound = lowerBound.isPreceding() ? RexWindowBounds.preceding(literal) : RexWindowBounds.following(literal); + windowChanged = true; + } + } + RexWindowBound upperBound = oldWindowGroup.upperBound; + offset = upperBound.getOffset(); + if (offset != null) { + RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects); + if (literal != null) { + upperBound = upperBound.isFollowing() ? RexWindowBounds.following(literal) : RexWindowBounds.preceding(literal); + windowChanged = true; + } } - return oldWindowGroup; + + return windowChanged ? 
new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, lowerBound, upperBound, + oldWindowGroup.orderKeys, newAggCallWindow) : oldWindowGroup; + } + + @Nullable + private RexLiteral getLiteral(RexNode rexNode, int numInputFields, ImmutableList constants, + @Nullable List projects) { + if (!(rexNode instanceof RexInputRef)) { + return null; + } + int index = ((RexInputRef) rexNode).getIndex(); + if (index >= numInputFields) { + return constants.get(index - numInputFields); + } + if (projects != null) { + RexNode project = projects.get(index); + if (project instanceof RexLiteral) { + return (RexLiteral) project; + } + } + return null; } private void validateWindows(Window window) { int numGroups = window.groups.size(); // For Phase 1 we only handle single window groups - Preconditions.checkState(numGroups <= 1, + Preconditions.checkState(numGroups == 1, String.format("Currently only 1 window group is supported, query has %d groups", numGroups)); // Validate that only supported window aggregation functions are present @@ -209,8 +244,7 @@ private void validateWindows(Window window) { } private void validateWindowAggCallsSupported(Window.Group windowGroup) { - for (int i = 0; i < windowGroup.aggCalls.size(); i++) { - Window.RexWinAggCall aggCall = windowGroup.aggCalls.get(i); + for (Window.RexWinAggCall aggCall : windowGroup.aggCalls) { SqlKind aggKind = aggCall.getKind(); Preconditions.checkState(SUPPORTED_WINDOW_FUNCTION_KIND.contains(aggKind), String.format("Unsupported Window function kind: %s. Only aggregation functions are supported!", aggKind)); @@ -218,24 +252,15 @@ private void validateWindowAggCallsSupported(Window.Group windowGroup) { } private void validateWindowFrames(Window.Group windowGroup) { - // Has ROWS only aggregation call kind (e.g. ROW_NUMBER)? 
- boolean isRowsOnlyTypeAggregateCall = isRowsOnlyAggregationCallType(windowGroup.aggCalls); - // For Phase 1 only the default frame is supported - Preconditions.checkState(!windowGroup.isRows || isRowsOnlyTypeAggregateCall, - "Default frame must be of type RANGE and not ROWS unless this is a ROWS only aggregation function"); - Preconditions.checkState(windowGroup.lowerBound.isPreceding() && windowGroup.lowerBound.isUnbounded(), - String.format("Lower bound must be UNBOUNDED PRECEDING but it is: %s", windowGroup.lowerBound)); - if (windowGroup.orderKeys.getKeys().isEmpty() && !isRowsOnlyTypeAggregateCall) { - Preconditions.checkState(windowGroup.upperBound.isFollowing() && windowGroup.upperBound.isUnbounded(), - String.format("Upper bound must be UNBOUNDED FOLLOWING but it is: %s", windowGroup.upperBound)); - } else { - Preconditions.checkState(windowGroup.upperBound.isCurrentRow(), - String.format("Upper bound must be CURRENT ROW but it is: %s", windowGroup.upperBound)); - } - } + RexWindowBound lowerBound = windowGroup.lowerBound; + RexWindowBound upperBound = windowGroup.upperBound; - private boolean isRowsOnlyAggregationCallType(ImmutableList aggCalls) { - return aggCalls.stream().anyMatch(aggCall -> aggCall.getKind().equals(SqlKind.ROW_NUMBER)); + boolean hasOffset = (lowerBound.isPreceding() && !lowerBound.isUnbounded()) || (upperBound.isFollowing() + && !upperBound.isUnbounded()); + + if (!windowGroup.isRows) { + Preconditions.checkState(!hasOffset, "RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + } } private boolean isPartitionByOnlyQuery(Window.Group windowGroup) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java index 0baa29ce41e..13bbd791b29 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java +++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java @@ -311,10 +311,6 @@ public Void visitWindow(WindowNode node, Void context) { ImmutableBitSet keys = ImmutableBitSet.of(node.getKeys()); boolean isRow = node.getWindowFrameType() == WindowNode.WindowFrameType.ROWS; - // As explained in RelToPlanNodeConverter, Pinot only supports UNBOUND_PRECEDING - RexWindowBound lowerBound = RexWindowBounds.UNBOUNDED_PRECEDING; - RexWindowBound upperBound = node.getUpperBound() == Integer.MAX_VALUE ? RexWindowBounds.UNBOUNDED_FOLLOWING - : RexWindowBounds.CURRENT_ROW; RelCollation orderKeys = RelCollations.of(node.getCollations()); List aggCalls = new ArrayList<>(); @@ -332,7 +328,9 @@ public Void visitWindow(WindowNode node, Void context) { aggCalls.add(winCall); } - Window.Group group = new Window.Group(keys, isRow, lowerBound, upperBound, orderKeys, aggCalls); + Window.Group group = + new Window.Group(keys, isRow, getWindowBound(node.getLowerBound()), getWindowBound(node.getUpperBound()), + orderKeys, aggCalls); List constants = node.getConstants().stream().map(constant -> RexExpressionUtils.toRexLiteral(_builder, constant)) @@ -351,6 +349,20 @@ public Void visitWindow(WindowNode node, Void context) { return null; } + private RexWindowBound getWindowBound(int bound) { + if (bound == Integer.MIN_VALUE) { + return RexWindowBounds.UNBOUNDED_PRECEDING; + } else if (bound == Integer.MAX_VALUE) { + return RexWindowBounds.UNBOUNDED_FOLLOWING; + } else if (bound == 0) { + return RexWindowBounds.CURRENT_ROW; + } else if (bound < 0) { + return RexWindowBounds.preceding(_builder.literal(-bound)); + } else { + return RexWindowBounds.following(_builder.literal(bound)); + } + } + @Override public Void visitSetOp(SetOpNode node, Void context) { List inputs = inputsAsList(node); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index 5991b44ed8b..1885261417c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -211,18 +211,35 @@ private WindowNode convertLogicalWindow(LogicalWindow node) { WindowNode.WindowFrameType windowFrameType = windowGroup.isRows ? WindowNode.WindowFrameType.ROWS : WindowNode.WindowFrameType.RANGE; - // TODO: For now only the default frame is supported. Add support for custom frames including rows support. - // Frame literals come in the constants from the LogicalWindow and the bound.getOffset() stores the - // InputRef to the constants array offset by the input array length. These need to be extracted here and - // set to the bounds. - // Lower bound can only be unbounded preceding for now, set to Integer.MIN_VALUE - // Change PlanNodeToRelConverted once this limitation is removed here - int lowerBound = Integer.MIN_VALUE; - // Upper bound can only be unbounded following or current row for now - int upperBound = windowGroup.upperBound.isUnbounded() ? Integer.MAX_VALUE : 0; + int lowerBound; + if (windowGroup.lowerBound.isUnbounded()) { + // Lower bound can't be unbounded following + lowerBound = Integer.MIN_VALUE; + } else if (windowGroup.lowerBound.isCurrentRow()) { + lowerBound = 0; + } else { + // The literal value is extracted from the constants in the PinotWindowExchangeNodeInsertRule + RexLiteral offset = (RexLiteral) windowGroup.lowerBound.getOffset(); + lowerBound = offset == null ? Integer.MIN_VALUE + : (windowGroup.lowerBound.isPreceding() ? 
-1 * RexExpressionUtils.getValueAsInt(offset) + : RexExpressionUtils.getValueAsInt(offset)); + } + int upperBound; + if (windowGroup.upperBound.isUnbounded()) { + // Upper bound can't be unbounded preceding + upperBound = Integer.MAX_VALUE; + } else if (windowGroup.upperBound.isCurrentRow()) { + upperBound = 0; + } else { + // The literal value is extracted from the constants in the PinotWindowExchangeNodeInsertRule + RexLiteral offset = (RexLiteral) windowGroup.upperBound.getOffset(); + upperBound = offset == null ? Integer.MAX_VALUE + : (windowGroup.upperBound.isFollowing() ? RexExpressionUtils.getValueAsInt(offset) + : -1 * RexExpressionUtils.getValueAsInt(offset)); + } - // TODO: Constants are used to store constants needed such as the frame literals. For now just save this, need to - // extract the constant values into bounds as a part of frame support. + // TODO: The constants are already extracted in the PinotWindowExchangeNodeInsertRule, we can remove them from + // the WindowNode and plan serde. List constants = new ArrayList<>(node.constants.size()); for (RexLiteral constant : node.constants) { constants.add(RexExpressionUtils.fromRexLiteral(constant)); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java index 3b726e5f5f1..cfee6f36cb9 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java @@ -30,6 +30,7 @@ public class WindowNode extends BasePlanNode { private final List _collations; private final List _aggCalls; private final WindowFrameType _windowFrameType; + // Both these bounds are relative to current row; 0 means current row, -1 means previous row, 1 means next row, etc. 
private final int _lowerBound; private final int _upperBound; private final List _constants; diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java index f40f474f880..c43f5434402 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java @@ -81,6 +81,7 @@ public void testAggregateCaseToFilter() { String query = "EXPLAIN PLAN FOR SELECT SUM(CASE WHEN col1 = 'a' THEN 1 ELSE 0 END) FROM a"; String explain = _queryEnvironment.explainQuery(query, RANDOM_REQUEST_ID_GEN.nextLong()); + //@formatter:off assertEquals(explain, "Execution Plan\n" + "LogicalProject(EXPR$0=[CASE(=($1, 0), null:BIGINT, $0)])\n" @@ -89,6 +90,7 @@ public void testAggregateCaseToFilter() { + " PinotLogicalAggregate(group=[{}], agg#0=[COUNT() FILTER $0], agg#1=[COUNT()])\n" + " LogicalProject($f1=[=($0, _UTF-8'a')])\n" + " LogicalTableScan(table=[[default, a]])\n"); + //@formatter:on } private static void assertGroupBySingletonAfterJoin(DispatchableSubPlan dispatchableSubPlan, boolean shouldRewrite) { @@ -392,6 +394,85 @@ public void testDuplicateWithAlias() { assertTrue(e.getCause().getMessage().contains("Duplicate alias in WITH: 'tmp'")); } + @Test + public void testWindowFunctionsWithCustomWindowFrame() { + String queryWithDefaultWindow = "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2) FROM a"; + _queryEnvironment.planQuery(queryWithDefaultWindow); + + String sumQueryWithCustomRowsWindow = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)" + + " FROM a"; + _queryEnvironment.planQuery(sumQueryWithCustomRowsWindow); + + String queryWithUnboundedFollowingAsLowerBound = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED FOLLOWING AND 
" + + "UNBOUNDED FOLLOWING) FROM a"; + RuntimeException e = expectThrows(RuntimeException.class, + () -> _queryEnvironment.planQuery(queryWithUnboundedFollowingAsLowerBound)); + assertTrue( + e.getCause().getMessage().contains("UNBOUNDED FOLLOWING cannot be specified for the lower frame boundary")); + + String queryWithUnboundedPrecedingAsUpperBound = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND " + + "UNBOUNDED PRECEDING) FROM a"; + e = expectThrows(RuntimeException.class, + () -> _queryEnvironment.planQuery(queryWithUnboundedPrecedingAsUpperBound)); + assertTrue( + e.getCause().getMessage().contains("UNBOUNDED PRECEDING cannot be specified for the upper frame boundary")); + + String queryWithOffsetFollowingAsLowerBoundAndOffsetPrecedingAsUpperBound = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 FOLLOWING AND 1 PRECEDING) " + + "FROM a"; + e = expectThrows(RuntimeException.class, + () -> _queryEnvironment.planQuery(queryWithOffsetFollowingAsLowerBoundAndOffsetPrecedingAsUpperBound)); + assertTrue(e.getCause().getMessage() + .contains("Upper frame boundary cannot be PRECEDING when lower boundary is FOLLOWING")); + + String queryWithValidBounds = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 2 FOLLOWING AND 3 FOLLOWING) " + + "FROM a"; + _queryEnvironment.planQuery(queryWithValidBounds); + + // Custom RANGE window frame is not currently supported by Pinot + String sumQueryWithCustomRangeWindow = + "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col3 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 " + + "FOLLOWING) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(sumQueryWithCustomRangeWindow)); + assertTrue(e.getCause().getCause().getMessage() + .contains("RANGE window frame with offset PRECEDING / FOLLOWING is not supported")); + + // RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD with custom 
window frame are invalid + String rankQuery = + "SELECT col1, col2, RANK() OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) " + + "FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(rankQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + + String denseRankQuery = + "SELECT col1, col2, DENSE_RANK() OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND " + + "1 FOLLOWING) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(denseRankQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + + String rowNumberQuery = + "SELECT col1, col2, ROW_NUMBER() OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN UNBOUNDED PRECEDING AND " + + "CURRENT ROW) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(rowNumberQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + + String lagQuery = + "SELECT col1, col2, LAG(col2, 1) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND " + + "UNBOUNDED FOLLOWING) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(lagQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + + String leadQuery = + "SELECT col1, col2, LEAD(col2, 1) OVER (PARTITION BY col1 ORDER BY col2 RANGE BETWEEN CURRENT ROW AND " + + "UNBOUNDED FOLLOWING) FROM a"; + e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(leadQuery)); + assertTrue(e.getCause().getMessage().contains("ROW/RANGE not allowed")); + } + // -------------------------------------------------------------------------- // Test Utils. 
// -------------------------------------------------------------------------- diff --git a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json index ac8ef927843..452d75369a6 100644 --- a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json +++ b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json @@ -3565,14 +3565,14 @@ "exception_throwing_window_function_planning_tests": { "queries": [ { - "description": "unsupported custom frames - ORDER BY with two columns and RANGE", - "sql": "EXPLAIN PLAN FOR SELECT MIN(a.col3) OVER(ORDER BY a.col3, a.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM a", - "expectedException": ".*Upper bound must be CURRENT ROW but it is: UNBOUNDED FOLLOWING.*" + "description": "unsupported custom frames - ORDER BY with two columns and RANGE with offset bound", + "sql": "EXPLAIN PLAN FOR SELECT MIN(a.col3) OVER(ORDER BY a.col3, a.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM a", + "expectedException": ".*RANGE clause cannot be used with compound ORDER BY clause.*" }, { - "description": "unsupported custom frames - PARTITION BY and ORDER BY with two columns and RANGE", - "sql": "EXPLAIN PLAN FOR SELECT MIN(a.col3) OVER(PARTITION BY a.col2 ORDER BY a.col3, a.col1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM a", - "expectedException": ".*Upper bound must be CURRENT ROW but it is: UNBOUNDED FOLLOWING.*" + "description": "unsupported custom frames - PARTITION BY and ORDER BY with two columns and RANGE with offset bound", + "sql": "EXPLAIN PLAN FOR SELECT MIN(a.col3) OVER(PARTITION BY a.col2 ORDER BY a.col3, a.col1 RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) FROM a", + "expectedException": ".*RANGE clause cannot be used with compound ORDER BY clause.*" }, { "description": "Using aggregation inside ORDER BY within OVER", diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java index 7d97356eb2f..f60f771a199 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java @@ -19,8 +19,6 @@ package org.apache.pinot.query.runtime.operator; import com.google.common.base.Preconditions; -import it.unimi.dsi.fastutil.ints.IntOpenHashSet; -import it.unimi.dsi.fastutil.ints.IntSet; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -44,6 +42,7 @@ import org.apache.pinot.query.runtime.operator.utils.TypeUtils; import org.apache.pinot.query.runtime.operator.window.WindowFunction; import org.apache.pinot.query.runtime.operator.window.WindowFunctionFactory; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; import org.apache.pinot.query.runtime.plan.OpChainExecutionContext; import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.WindowOverFlowMode; import org.slf4j.Logger; @@ -128,20 +127,20 @@ public WindowAggregateOperator(OpChainExecutionContext context, MultiStageOperat _keys[i] = keys.get(i); } _windowFrame = new WindowFrame(node.getWindowFrameType(), node.getLowerBound(), node.getUpperBound()); - Preconditions.checkState(_windowFrame.isUnboundedPreceding(), - "Only default frame is supported, lowerBound must be UNBOUNDED PRECEDING"); - Preconditions.checkState(_windowFrame.isUnboundedFollowing() || _windowFrame.isUpperBoundCurrentRow(), - "Only default frame is supported, upperBound must be UNBOUNDED FOLLOWING or CURRENT ROW"); + Preconditions.checkState( + _windowFrame.isRowType() || ((_windowFrame.isUnboundedPreceding() || _windowFrame.isLowerBoundCurrentRow()) && ( + _windowFrame.isUnboundedFollowing() || 
_windowFrame.isUpperBoundCurrentRow())), + "RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + Preconditions.checkState(_windowFrame.getLowerBound() <= _windowFrame.getUpperBound(), + "Window frame lower bound can't be greater than upper bound"); List collations = node.getCollations(); - boolean partitionByOnly = isPartitionByOnlyQuery(_keys, collations); List aggCalls = node.getAggCalls(); int numAggCalls = aggCalls.size(); _windowFunctions = new WindowFunction[numAggCalls]; for (int i = 0; i < numAggCalls; i++) { RexExpression.FunctionCall aggCall = aggCalls.get(i); - validateAggregationCalls(aggCall.getFunctionName()); _windowFunctions[i] = - WindowFunctionFactory.construnctWindowFunction(aggCall, inputSchema, collations, partitionByOnly); + WindowFunctionFactory.constructWindowFunction(aggCall, inputSchema, collations, _windowFrame); } Map metadata = context.getOpChainMetadata(); @@ -209,34 +208,6 @@ protected TransferableBlock getNextBlock() return computeBlocks(); } - private void validateAggregationCalls(String functionName) { - if (ROWS_ONLY_FUNCTION_NAMES.contains(functionName)) { - Preconditions.checkState( - _windowFrame._type == WindowNode.WindowFrameType.ROWS && _windowFrame.isUpperBoundCurrentRow(), - String.format("%s must be of ROW frame type and have CURRENT ROW as the upper bound", functionName)); - } else { - Preconditions.checkState(_windowFrame._type == WindowNode.WindowFrameType.RANGE, - String.format("Only RANGE type frames are supported at present for function: %s", functionName)); - } - } - - private boolean isPartitionByOnlyQuery(int[] keys, List collations) { - if (collations.isEmpty()) { - return true; - } - int numKeys = keys.length; - if (numKeys != collations.size()) { - return false; - } - IntSet keyIndices = new IntOpenHashSet(numKeys); - IntSet orderFieldIndices = new IntOpenHashSet(numKeys); - for (int i = 0; i < numKeys; i++) { - keyIndices.add(keys[i]); - 
orderFieldIndices.add(collations.get(i).getFieldIndex()); - } - return keyIndices.equals(orderFieldIndices); - } - /** * @return the final block, which must be either an end of stream or an error. */ @@ -313,37 +284,6 @@ private TransferableBlock computeBlocks() } } - /** - * Defines the Frame to be used for the window query. The 'lowerBound' and 'upperBound' indicate the frame - * boundaries to be used. Whereas, 'isRows' is used to differentiate between RANGE and ROWS type frames. - */ - private static class WindowFrame { - // Enum to denote the FRAME type, can be either ROW or RANGE types - final WindowNode.WindowFrameType _type; - // The lower bound of the frame. Set to Integer.MIN_VALUE if UNBOUNDED PRECEDING - final int _lowerBound; - // The lower bound of the frame. Set to Integer.MAX_VALUE if UNBOUNDED FOLLOWING. Set to 0 if CURRENT ROW - final int _upperBound; - - WindowFrame(WindowNode.WindowFrameType type, int lowerBound, int upperBound) { - _type = type; - _lowerBound = lowerBound; - _upperBound = upperBound; - } - - boolean isUnboundedPreceding() { - return _lowerBound == Integer.MIN_VALUE; - } - - boolean isUnboundedFollowing() { - return _upperBound == Integer.MAX_VALUE; - } - - boolean isUpperBoundCurrentRow() { - return _upperBound == 0; - } - } - public enum StatKey implements StatMap.Key { //@formatter:off EXECUTION_TIME_MS(StatMap.Type.LONG) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java index 4501d638f45..6b5a1b81201 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java @@ -24,6 +24,7 @@ import org.apache.pinot.query.planner.logical.RexExpression; import 
org.apache.pinot.query.runtime.operator.WindowAggregateOperator; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; /** @@ -33,18 +34,18 @@ */ public abstract class WindowFunction extends AggregationUtils.Accumulator { protected final int[] _orderKeys; - protected final boolean _partitionByOnly; protected final int[] _inputRefs; + protected final WindowFrame _windowFrame; public WindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, List collations, - boolean partitionByOnly) { + WindowFrame windowFrame) { super(aggCall, inputSchema); int numOrderKeys = collations.size(); _orderKeys = new int[numOrderKeys]; for (int i = 0; i < numOrderKeys; i++) { _orderKeys[i] = collations.get(i).getFieldIndex(); } - _partitionByOnly = partitionByOnly; + _windowFrame = windowFrame; if (WindowAggregateOperator.RANKING_FUNCTION_NAMES.contains(aggCall.getFunctionName())) { _inputRefs = _orderKeys; } else { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java index 990524448c7..b794a6bcb85 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java @@ -27,7 +27,8 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.runtime.operator.window.aggregate.AggregateWindowFunction; -import org.apache.pinot.query.runtime.operator.window.range.RangeWindowFunction; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; +import org.apache.pinot.query.runtime.operator.window.range.RankBasedWindowFunction; import 
org.apache.pinot.query.runtime.operator.window.value.ValueWindowFunction; @@ -38,20 +39,24 @@ public class WindowFunctionFactory { private WindowFunctionFactory() { } + //@formatter:off public static final Map> WINDOW_FUNCTION_MAP = - ImmutableMap.>builder().putAll(RangeWindowFunction.WINDOW_FUNCTION_MAP) - .putAll(ValueWindowFunction.WINDOW_FUNCTION_MAP).build(); + ImmutableMap.>builder() + .putAll(RankBasedWindowFunction.WINDOW_FUNCTION_MAP) + .putAll(ValueWindowFunction.WINDOW_FUNCTION_MAP) + .build(); + //@formatter:on - public static WindowFunction construnctWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { + public static WindowFunction constructWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, + List collations, WindowFrame windowFrame) { String functionName = aggCall.getFunctionName(); Class windowFunctionClass = WINDOW_FUNCTION_MAP.getOrDefault(functionName, AggregateWindowFunction.class); try { Constructor constructor = windowFunctionClass.getConstructor(RexExpression.FunctionCall.class, DataSchema.class, List.class, - boolean.class); - return constructor.newInstance(aggCall, inputSchema, collations, partitionByOnly); + WindowFrame.class); + return constructor.newInstance(aggCall, inputSchema, collations, windowFrame); } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { throw new RuntimeException("Failed to instantiate WindowFunction for function: " + functionName, e); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java index 6763542bd0d..9598ae5f385 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java +++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -39,8 +40,8 @@ public class AggregateWindowFunction extends WindowFunction { private final Merger _merger; public AggregateWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); String functionName = aggCall.getFunctionName(); Function mergerCreator = AggregationUtils.Accumulator.MERGERS.get(functionName); Preconditions.checkArgument(mergerCreator != null, "Unsupported aggregate function: %s", functionName); @@ -49,82 +50,228 @@ public AggregateWindowFunction(RexExpression.FunctionCall aggCall, DataSchema in @Override public final List processRows(List rows) { - if (_partitionByOnly) { - return processPartitionOnlyRows(rows); + if (_windowFrame.isRowType()) { + return processRowsWindow(rows); } else { - return processRowsInternal(rows); + return processRangeWindow(rows); } } - protected List processPartitionOnlyRows(List rows) { + /** + * Process windows where both ends are unbounded. Both ROWS and RANGE windows can be processed similarly. + */ + private List processUnboundedPrecedingAndFollowingWindow(List rows) { + // Process all rows at once Object mergedResult = null; for (Object[] row : rows) { - Object value = _inputRef == -1 ? 
_literal : row[_inputRef]; - if (value == null) { - continue; + mergedResult = getMergedResult(mergedResult, row); + } + return Collections.nCopies(rows.size(), mergedResult); + } + + @Nullable + private Object getMergedResult(Object currentResult, Object[] row) { + Object value = _inputRef == -1 ? _literal : row[_inputRef]; + if (value == null) { + return currentResult; + } + if (currentResult == null) { + return _merger.init(value, _dataType); + } else { + return _merger.merge(currentResult, value); + } + } + + private List processRowsWindow(List rows) { + if (_windowFrame.isUnboundedPreceding() && _windowFrame.isUnboundedFollowing()) { + return processUnboundedPrecedingAndFollowingWindow(rows); + } + + if (_windowFrame.isUnboundedPreceding()) { + int upperBound = _windowFrame.getUpperBound(); + List results = new ArrayList<>(rows.size()); + Object mergedResult = null; + + // Calculate first window result + if (upperBound >= 0) { + for (int i = 0; i <= upperBound; i++) { + if (i < rows.size()) { + mergedResult = getMergedResult(mergedResult, rows.get(i)); + } else { + break; + } + } + } + + for (int i = 0; i < rows.size(); i++) { + results.add(mergedResult); + + // Update merged result for next row + if (i + upperBound + 1 < rows.size() && i + upperBound + 1 >= 0) { + mergedResult = getMergedResult(mergedResult, rows.get(i + upperBound + 1)); + } + } + + return results; + } + + if (_windowFrame.isUnboundedFollowing()) { + int lowerBound = _windowFrame.getLowerBound(); + List results = new ArrayList<>(rows.size()); + Object mergedResult = null; + + // Calculate last window result + if (lowerBound <= 0) { + for (int i = rows.size() - 1; i >= rows.size() - 1 + lowerBound; i--) { + if (i >= 0) { + mergedResult = getMergedResult(mergedResult, rows.get(i)); + } else { + break; + } + } + } + + for (int i = rows.size() - 1; i >= 0; i--) { + results.add(mergedResult); + + // Update merged result for next row + if (i + lowerBound - 1 < rows.size() && i + lowerBound - 1 >= 
0) { + mergedResult = getMergedResult(mergedResult, rows.get(i + lowerBound - 1)); + } } - if (mergedResult == null) { - mergedResult = _merger.init(value, _dataType); + + Collections.reverse(results); + return results; + } + + int lowerBound = _windowFrame.getLowerBound(); + int upperBound = _windowFrame.getUpperBound(); + + // TODO: Optimize this to avoid recomputing the merged result for each window from scratch. We can use a simple + // sliding window algorithm for aggregations like SUM and COUNT. For MIN, MAX, etc. we'll need an additional + // structure like a deque or priority queue to keep track of the minimum / maximum values in the sliding window. + List results = new ArrayList<>(rows.size()); + for (int i = 0; i < rows.size(); i++) { + int lower; + int upper; + + if ((long) i + lowerBound >= rows.size()) { + // Fill rest of the rows with null since all subsequent windows will be out of bounds + for (int j = i; j < rows.size(); j++) { + results.add(null); + } + break; + } + lower = i + lowerBound; + + if ((long) i + upperBound >= rows.size()) { + upper = rows.size() - 1; } else { - mergedResult = _merger.merge(mergedResult, value); + upper = i + upperBound; + } + + Object mergedResult = null; + if (upper >= 0) { + for (int j = lower; j <= upper; j++) { + if (j >= 0 && j < rows.size()) { + mergedResult = getMergedResult(mergedResult, rows.get(j)); + } + if (j >= rows.size()) { + break; + } + } } + results.add(mergedResult); } - return Collections.nCopies(rows.size(), mergedResult); + return results; } - protected List processRowsInternal(List rows) { - Key emptyOrderKey = AggregationUtils.extractEmptyKey(); - OrderKeyResult orderByResult = new OrderKeyResult(); - for (Object[] row : rows) { - // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as - // we output the rows since the aggregation value depends on the neighboring rows. - Key orderKey = (_partitionByOnly && _orderKeys.length == 0) ? 
emptyOrderKey - : AggregationUtils.extractRowKey(row, _orderKeys); - - Key previousOrderKeyIfPresent = orderByResult.getPreviousOrderByKey(); - Object currentRes = - previousOrderKeyIfPresent == null ? null : orderByResult.getOrderByResults().get(previousOrderKeyIfPresent); - Object value = _inputRef == -1 ? _literal : row[_inputRef]; - if (currentRes == null) { - orderByResult.addOrderByResult(orderKey, _merger.init(value, _dataType)); - } else { - orderByResult.addOrderByResult(orderKey, _merger.merge(currentRes, value)); + private List processRangeWindow(List rows) { + // We don't currently support RANGE windows with offset FOLLOWING / PRECEDING and this is validated during planning + // so we can safely assume that the lower bound is either UNBOUNDED PRECEDING or CURRENT ROW and the upper bound + // is either UNBOUNDED FOLLOWING or CURRENT ROW. + + if (_windowFrame.isUnboundedPreceding() && _windowFrame.isUnboundedFollowing()) { + return processUnboundedPrecedingAndFollowingWindow(rows); + } + + KeyedResults orderByResult = new KeyedResults(); + if (_windowFrame.isUnboundedPreceding() && _windowFrame.isUpperBoundCurrentRow()) { + for (Object[] row : rows) { + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Key previousOrderKeyIfPresent = orderByResult.getPreviousKey(); + Object currentRes = + previousOrderKeyIfPresent == null ? null : orderByResult.getKeyedResults().get(previousOrderKeyIfPresent); + Object value = _inputRef == -1 ? 
_literal : row[_inputRef]; + if (currentRes == null) { + orderByResult.addResult(orderKey, _merger.init(value, _dataType)); + } else { + orderByResult.addResult(orderKey, _merger.merge(currentRes, value)); + } + } + } else if (_windowFrame.isLowerBoundCurrentRow() && _windowFrame.isUnboundedFollowing()) { + // Do a reverse iteration + for (int i = rows.size() - 1; i >= 0; i--) { + Object[] row = rows.get(i); + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Key previousOrderKeyIfPresent = orderByResult.getPreviousKey(); + Object currentRes = + previousOrderKeyIfPresent == null ? null : orderByResult.getKeyedResults().get(previousOrderKeyIfPresent); + Object value = _inputRef == -1 ? _literal : row[_inputRef]; + if (currentRes == null) { + orderByResult.addResult(orderKey, _merger.init(value, _dataType)); + } else { + orderByResult.addResult(orderKey, _merger.merge(currentRes, value)); + } + } + } else if (_windowFrame.isLowerBoundCurrentRow() && _windowFrame.isUpperBoundCurrentRow()) { + for (Object[] row : rows) { + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Object currentRes = orderByResult.getKeyedResults().get(orderKey); + Object value = _inputRef == -1 ? _literal : row[_inputRef]; + if (currentRes == null) { + orderByResult.addResult(orderKey, _merger.init(value, _dataType)); + } else { + orderByResult.addResult(orderKey, _merger.merge(currentRes, value)); + } } + } else { + throw new IllegalStateException("RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); } + List results = new ArrayList<>(rows.size()); for (Object[] row : rows) { - // Only need to accumulate the aggregate function values for RANGE type. ROW type can be calculated as - // we output the rows since the aggregation value depends on the neighboring rows. - Key orderKey = (_partitionByOnly && _orderKeys.length == 0) ? 
emptyOrderKey - : AggregationUtils.extractRowKey(row, _orderKeys); - Object value = orderByResult.getOrderByResults().get(orderKey); + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Object value = orderByResult.getKeyedResults().get(orderKey); results.add(value); } return results; } - static class OrderKeyResult { - final Map _orderByResults; - Key _previousOrderByKey; + // Used to maintain running results for each key. Note that the key here is not the partition key but the key + // generated for a row based on the window's ORDER BY keys. + private static class KeyedResults { + final Map _keyedResults; + Key _previousKey; - OrderKeyResult() { - _orderByResults = new HashMap<>(); - _previousOrderByKey = null; + KeyedResults() { + _keyedResults = new HashMap<>(); + _previousKey = null; } - public void addOrderByResult(Key orderByKey, Object value) { + void addResult(Key key, Object value) { // We expect to get the rows in order based on the ORDER BY key so it is safe to blindly assign the // current key as the previous key - _orderByResults.put(orderByKey, value); - _previousOrderByKey = orderByKey; + _keyedResults.put(key, value); + _previousKey = key; } - public Map getOrderByResults() { - return _orderByResults; + Map getKeyedResults() { + return _keyedResults; } - public Key getPreviousOrderByKey() { - return _previousOrderByKey; + Key getPreviousKey() { + return _previousKey; } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowFrame.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowFrame.java new file mode 100644 index 00000000000..0199cdac555 --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowFrame.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.aggregate; + +import org.apache.pinot.query.planner.plannode.WindowNode; + + +/** + * Defines the window frame to be used for a window function. The 'lowerBound' and 'upperBound' indicate the frame + * boundaries to be used. The frame can be of two types: ROWS or RANGE. + */ +public class WindowFrame { + // Enum to denote the FRAME type, can be either ROWS or RANGE types + private final WindowNode.WindowFrameType _type; + // Both these bounds are relative to current row; 0 means current row, -1 means previous row, 1 means next row, etc. + // Integer.MIN_VALUE represents UNBOUNDED PRECEDING which is only allowed for the lower bound (ensured by Calcite). + // Integer.MAX_VALUE represents UNBOUNDED FOLLOWING which is only allowed for the upper bound (ensured by Calcite). 
+ private final int _lowerBound; + private final int _upperBound; + + public WindowFrame(WindowNode.WindowFrameType type, int lowerBound, int upperBound) { + _type = type; + _lowerBound = lowerBound; + _upperBound = upperBound; + } + + public boolean isUnboundedPreceding() { + return _lowerBound == Integer.MIN_VALUE; + } + + public boolean isUnboundedFollowing() { + return _upperBound == Integer.MAX_VALUE; + } + + public boolean isLowerBoundCurrentRow() { + return _lowerBound == 0; + } + + public boolean isUpperBoundCurrentRow() { + return _upperBound == 0; + } + + public boolean isRowType() { + return _type == WindowNode.WindowFrameType.ROWS; + } + + public int getLowerBound() { + return _lowerBound; + } + + public int getUpperBound() { + return _upperBound; + } + + @Override + public String toString() { + return "WindowFrame{" + "type=" + _type + ", lowerBound=" + _lowerBound + ", upperBound=" + _upperBound + '}'; + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java index 02b3993fa0f..a92c2f606a0 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java @@ -23,13 +23,14 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; -public class DenseRankWindowFunction extends RangeWindowFunction { +public class DenseRankWindowFunction extends RankBasedWindowFunction { public DenseRankWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { 
- super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java similarity index 78% rename from pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java index ab1994dd28e..da16bbfe9eb 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankBasedWindowFunction.java @@ -25,22 +25,27 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.runtime.operator.window.WindowFunction; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; -public abstract class RangeWindowFunction extends WindowFunction { +/** + * Rank based window functions don't support custom window frames (ROWS / RANGE) and are computed over the + * entire partition. Calcite enforces that a custom window frame cannot be specified for these functions. 
+ */ +public abstract class RankBasedWindowFunction extends WindowFunction { //@formatter:off public static final Map> WINDOW_FUNCTION_MAP = ImmutableMap.>builder() - // Range window functions + // Rank based window functions .put("ROW_NUMBER", RowNumberWindowFunction.class) .put("RANK", RankWindowFunction.class) .put("DENSE_RANK", DenseRankWindowFunction.class) .build(); //@formatter:on - public RangeWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + public RankBasedWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } protected int compareRows(Object[] leftRow, Object[] rightRow) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java index b6911089e48..a0bc1aea967 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java @@ -23,13 +23,14 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; -public class RankWindowFunction extends RangeWindowFunction { +public class RankWindowFunction extends RankBasedWindowFunction { public RankWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, 
inputSchema, collations, windowFrame); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java index a306af4706a..cb5821691f1 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java @@ -23,13 +23,14 @@ import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; -public class RowNumberWindowFunction extends RangeWindowFunction { +public class RowNumberWindowFunction extends RankBasedWindowFunction { public RowNumberWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } @Override diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java index 34e257451a9..32bc3cb28e3 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java @@ -18,22 +18,87 @@ */ package org.apache.pinot.query.runtime.operator.window.value; +import com.google.common.base.Preconditions; +import java.util.ArrayList; import 
java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.data.table.Key; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; public class FirstValueWindowFunction extends ValueWindowFunction { public FirstValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } @Override public List processRows(List rows) { + if (_windowFrame.isRowType()) { + return processRowsWindow(rows); + } else { + return processRangeWindow(rows); + } + } + + private List processRowsWindow(List rows) { + if (_windowFrame.isUnboundedPreceding() && _windowFrame.getUpperBound() >= 0) { + return processUnboundedPreceding(rows); + } + + int numRows = rows.size(); + List result = new ArrayList<>(numRows); + + // lowerBound is guaranteed to be less than or equal to upperBound here (but both can be -ve / +ve) + int lowerBound = _windowFrame.getLowerBound(); + int upperBound = _windowFrame.getUpperBound(); + + for (int i = 0; i < numRows; i++) { + // We want to make sure to avoid overflows + long lower = (long) lowerBound + i; + long upper = (long) upperBound + i; + + if (lower >= rows.size() || upper < 0) { + result.add(null); + continue; + } + + result.add(extractValueFromRow(rows.get(Math.max(0, (int) lower)))); + } + return result; + } + + private List processRangeWindow(List rows) { + if (_windowFrame.isUnboundedPreceding()) { + return processUnboundedPreceding(rows); + } + + // The lower bound has to be CURRENT ROW since we don't support RANGE windows 
with offset value + Preconditions.checkState(_windowFrame.isLowerBoundCurrentRow(), + "RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + + int numRows = rows.size(); + List result = new ArrayList<>(numRows); + + Map firstValueForKey = new HashMap<>(); + for (Object[] row : rows) { + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Object value = extractValueFromRow(row); + Object prev = firstValueForKey.putIfAbsent(orderKey, value); + result.add(prev != null ? prev : value); + } + + return result; + } + + private List processUnboundedPreceding(List rows) { int numRows = rows.size(); assert numRows > 0; Object value = extractValueFromRow(rows.get(0)); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java index 797fca9313b..013cdd77c40 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java @@ -25,15 +25,19 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; +/** + * The LAG window function doesn't allow custom window frames (and this is enforced by Calcite). 
+ */ public class LagValueWindowFunction extends ValueWindowFunction { private final int _offset; private final Object _defaultValue; public LagValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); int offset = 1; Object defaultValue = null; List operands = aggCall.getFunctionOperands(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java index 7c8c4b2260f..5f45c8ea3f1 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java @@ -18,22 +18,91 @@ */ package org.apache.pinot.query.runtime.operator.window.value; +import com.google.common.base.Preconditions; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.data.table.Key; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; public class LastValueWindowFunction extends ValueWindowFunction { public LastValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame 
windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } @Override public List processRows(List rows) { + if (_windowFrame.isRowType()) { + return processRowsWindow(rows); + } else { + return processRangeWindow(rows); + } + } + + private List processRowsWindow(List rows) { + if (_windowFrame.isUnboundedFollowing() && _windowFrame.getLowerBound() <= 0) { + return processUnboundedFollowing(rows); + } + + int numRows = rows.size(); + List result = new ArrayList<>(numRows); + + // lowerBound is guaranteed to be less than or equal to upperBound here (but both can be -ve / +ve) + int lowerBound = _windowFrame.getLowerBound(); + int upperBound = _windowFrame.getUpperBound(); + + for (int i = 0; i < numRows; i++) { + // We want to make sure to avoid overflows + long lower = (long) lowerBound + i; + long upper = (long) upperBound + i; + + if (lower >= rows.size() || upper < 0) { + result.add(null); + continue; + } + + result.add(extractValueFromRow(rows.get((int) Math.min(upper, rows.size() - 1)))); + } + + return result; + } + + private List processRangeWindow(List rows) { + if (_windowFrame.isUnboundedFollowing()) { + return processUnboundedFollowing(rows); + } + + // The upper bound has to be CURRENT ROW here since we don't support RANGE windows with offset value + Preconditions.checkState(_windowFrame.isUpperBoundCurrentRow(), + "RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + + int numRows = rows.size(); + List result = new ArrayList<>(numRows); + Map lastValueForKey = new HashMap<>(); + + for (int i = numRows - 1; i >= 0; i--) { + Object[] row = rows.get(i); + Key orderKey = AggregationUtils.extractRowKey(row, _orderKeys); + Object value = extractValueFromRow(row); + Object prev = lastValueForKey.putIfAbsent(orderKey, value); + result.add(prev != null ? 
prev : value); + } + + Collections.reverse(result); + return result; + } + + private List processUnboundedFollowing(List rows) { int numRows = rows.size(); assert numRows > 0; Object value = extractValueFromRow(rows.get(numRows - 1)); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java index 099c3fba5f9..b81913f6c0e 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java @@ -25,16 +25,20 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; +/** + * The LEAD window function doesn't allow custom window frames (and this is enforced by Calcite). 
+ */ public class LeadValueWindowFunction extends ValueWindowFunction { private final int _offset; private final Object _defaultValue; public LeadValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); int offset = 1; Object defaultValue = null; List operands = aggCall.getFunctionOperands(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java index a05c12e2486..3f8a1a8422b 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java @@ -25,6 +25,7 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.runtime.operator.window.WindowFunction; +import org.apache.pinot.query.runtime.operator.window.aggregate.WindowFrame; public abstract class ValueWindowFunction extends WindowFunction { @@ -40,8 +41,8 @@ public abstract class ValueWindowFunction extends WindowFunction { //@formatter:on public ValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema, - List collations, boolean partitionByOnly) { - super(aggCall, inputSchema, collations, partitionByOnly); + List collations, WindowFrame windowFrame) { + super(aggCall, inputSchema, collations, windowFrame); } protected Object extractValueFromRow(Object[] row) { diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index ec5d667254b..07057a80ce5 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -38,14 +38,18 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils; import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils; import org.mockito.Mock; +import org.testng.Assert; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.DOUBLE; import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.INT; import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.LONG; import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.STRING; +import static org.apache.pinot.query.planner.plannode.WindowNode.WindowFrameType.RANGE; +import static org.apache.pinot.query.planner.plannode.WindowNode.WindowFrameType.ROWS; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -323,7 +327,7 @@ public void testRowNumberRankingFunction() { List aggCalls = List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.ROW_NUMBER.name(), List.of())); WindowAggregateOperator operator = - getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.ROWS, + getOperator(inputSchema, resultSchema, keys, collations, aggCalls, ROWS, Integer.MIN_VALUE, 0); // When: @@ -351,9 +355,10 @@ public void testNonEmptyOrderByKeysNotMatchingPartitionByKeys() { List collations = List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, 
RelFieldCollation.NullDirection.LAST)); List aggCalls = List.of(getSum(new RexExpression.InputRef(0))); + // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW (default window frame for ORDER BY) WindowAggregateOperator operator = getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE, - Integer.MIN_VALUE, Integer.MAX_VALUE); + Integer.MIN_VALUE, 0); // When: List resultRows = operator.nextBlock().getContainer(); @@ -420,59 +425,6 @@ public void testNonEmptyOrderByKeysMatchingPartitionByKeysWithDifferentDirection assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); } - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Only RANGE type frames " - + "are supported at present.*") - public void testShouldThrowOnInvalidRowsFunction() { - // Given: - DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); - when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, "foo"})) - .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - DataSchema resultSchema = - new DataSchema(new String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, STRING, DOUBLE}); - List keys = List.of(0); - List aggCalls = List.of(getSum(new RexExpression.InputRef(1))); - - // Then: - getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, WindowNode.WindowFrameType.ROWS, - Integer.MIN_VALUE, Integer.MAX_VALUE); - } - - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Only default frame is " - + "supported, lowerBound must be UNBOUNDED PRECEDING") - public void testShouldThrowOnCustomFramesCustomPreceding() { - // TODO: Remove this test once custom frame support is added - // Given: - DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); - 
when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, "foo"})) - .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - DataSchema resultSchema = - new DataSchema(new String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, STRING, DOUBLE}); - List keys = List.of(0); - List aggCalls = List.of(getSum(new RexExpression.InputRef(1))); - - // Then: - getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, WindowNode.WindowFrameType.RANGE, 5, - Integer.MAX_VALUE); - } - - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Only default frame is " - + "supported, upperBound must be UNBOUNDED FOLLOWING or CURRENT ROW") - public void testShouldThrowOnCustomFramesCustomFollowing() { - // TODO: Remove this test once custom frame support is added - // Given: - DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); - when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, "foo"})) - .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - DataSchema resultSchema = - new DataSchema(new String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, STRING, DOUBLE}); - List keys = List.of(0); - List aggCalls = List.of(getSum(new RexExpression.InputRef(1))); - - // Then: - getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, WindowNode.WindowFrameType.RANGE, - Integer.MIN_VALUE, 5); - } - @Test public void testShouldReturnErrorBlockOnUnexpectedInputType() { // Given: @@ -567,16 +519,16 @@ public void testLeadLagWindowFunction() { OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"}, new Object[]{2, "the"}, new Object[]{3, "true"})) .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"}, - new 
ColumnDataType[]{INT, STRING, INT, INT}); + DataSchema resultSchema = + new DataSchema(new String[]{"group", "arg", "lead", "lag"}, new ColumnDataType[]{INT, STRING, INT, INT}); List keys = List.of(0); List collations = List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST)); - List aggCalls = - List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), - List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))), - new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), - List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1)))); + List aggCalls = List.of( + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1))), + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1)))); WindowAggregateOperator operator = getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE, Integer.MIN_VALUE, 0); @@ -584,6 +536,7 @@ public void testLeadLagWindowFunction() { // When: List resultRows = operator.nextBlock().getContainer(); // Then: + //@formatter:off verifyResultRows(resultRows, keys, Map.of( 1, List.of( new Object[]{1, "foo", 1, null}, @@ -598,6 +551,7 @@ public void testLeadLagWindowFunction() { new Object[]{3, "and", 3, null}, new Object[]{3, "true", null, 3}) )); + //@formatter:on assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); } @@ -611,18 +565,18 @@ public void testLeadLagWindowFunction2() { OperatorTestUtil.block(inputSchema, new Object[]{1, "foo"}, new Object[]{2, "foo"}, new Object[]{1, "numb"}, new Object[]{2, "the"}, new Object[]{3, "true"})) 
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); - DataSchema resultSchema = new DataSchema(new String[]{"group", "arg", "lead", "lag"}, - new ColumnDataType[]{INT, STRING, INT, INT}); + DataSchema resultSchema = + new DataSchema(new String[]{"group", "arg", "lead", "lag"}, new ColumnDataType[]{INT, STRING, INT, INT}); List keys = List.of(0); List collations = List.of(new RelFieldCollation(1, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST)); - List aggCalls = - List.of(new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), - List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 2), - new RexExpression.Literal(ColumnDataType.INT, 100))), - new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), - List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1), - new RexExpression.Literal(ColumnDataType.INT, 200)))); + List aggCalls = List.of( + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LEAD.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 2), + new RexExpression.Literal(ColumnDataType.INT, 100))), + new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAG.name(), + List.of(new RexExpression.InputRef(0), new RexExpression.Literal(ColumnDataType.INT, 1), + new RexExpression.Literal(ColumnDataType.INT, 200)))); WindowAggregateOperator operator = getOperator(inputSchema, resultSchema, keys, collations, aggCalls, WindowNode.WindowFrameType.RANGE, Integer.MIN_VALUE, 0); @@ -630,6 +584,7 @@ public void testLeadLagWindowFunction2() { // When: List resultRows = operator.nextBlock().getContainer(); // Then: + //@formatter:off verifyResultRows(resultRows, keys, Map.of( 1, List.of( new Object[]{1, "foo", 1, 200}, @@ -644,9 +599,1340 @@ public void testLeadLagWindowFunction2() { new Object[]{3, "and", 100, 200}, new Object[]{3, "true", 100, 3}) )); + //@formatter:on + 
assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testSumWithUnboundedPrecedingLowerAndUnboundedFollowingUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, frameType, Integer.MIN_VALUE, Integer.MAX_VALUE, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then (result should be the same for both window frame types): + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 59.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 59.0}, + new Object[]{"A", 15, 2008, 59.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testSumWithUnboundedPrecedingLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, frameType, Integer.MIN_VALUE, 0, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new 
Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14.0}, + new Object[]{"A", 10, 2002, 24.0}, + new Object[]{"A", 20, 2008, frameType == ROWS ? 44.0 : 59.0}, + new Object[]{"A", 15, 2008, 59.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithUnboundedPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, Integer.MIN_VALUE, 2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 44.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 59.0}, + new Object[]{"A", 15, 2008, 59.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithUnboundedPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new 
ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, Integer.MIN_VALUE, -2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, null}, + new Object[]{"A", 20, 2008, 14.0}, + new Object[]{"A", 15, 2008, 24.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testSumWithCurrentRowLowerAndUnboundedFollowingUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, frameType, 0, Integer.MAX_VALUE, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 59.0}, + new Object[]{"A", 10, 2002, 45.0}, + new Object[]{"A", 20, 2008, 35.0}, + new Object[]{"A", 15, 2008, frameType == ROWS ? 
15.0 : 35.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 20.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testSumWithCurrentRowLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, frameType, 0, 0, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14.0}, + new Object[]{"A", 10, 2002, 10.0}, + new Object[]{"A", 20, 2008, frameType == ROWS ? 20.0 : 35.0}, + new Object[]{"A", 15, 2008, frameType == ROWS ? 
15.0 : 35.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10.0}, + new Object[]{"B", 20, 2005, 20.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithCurrentRowLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, 0, 2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 44.0}, + new Object[]{"A", 10, 2002, 45.0}, + new Object[]{"A", 20, 2008, 35.0}, + new Object[]{"A", 15, 2008, 15.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 20.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetPrecedingLowerAndUnboundedFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, -1, Integer.MAX_VALUE, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = 
operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 59.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 45.0}, + new Object[]{"A", 15, 2008, 35.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetFollowingLowerAndUnboundedFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, 1, Integer.MAX_VALUE, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 45.0}, + new Object[]{"A", 10, 2002, 35.0}, + new Object[]{"A", 20, 2008, 15.0}, + new Object[]{"A", 15, 2008, null} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20.0}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetPrecedingLowerAndCurrentRowUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, -2, 0, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + 
new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14.0}, + new Object[]{"A", 10, 2002, 24.0}, + new Object[]{"A", 20, 2008, 44.0}, + new Object[]{"A", 15, 2008, 45.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, -1, 2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 44.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 45.0}, + new Object[]{"A", 15, 2008, 35.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithVeryLargeOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator 
operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + // Verify if overflows are handled correctly + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, -1, 2147483646, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 59.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 45.0}, + new Object[]{"A", 15, 2008, 35.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithVeryLargeOffsetFollowingLower() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + // Verify if overflows are handled correctly + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, 2147483646, 2147483647, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, null}, + new Object[]{"A", 20, 2008, null}, + new Object[]{"A", 15, 2008, null} + ), + "B", List.of( + new Object[]{"B", 
10, 2000, null}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, -3, -2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, null}, + new Object[]{"A", 20, 2008, 14.0}, + new Object[]{"A", 15, 2008, 24.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithOffsetFollowingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 2, ROWS, 1, 2, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, 
List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 30.0}, + new Object[]{"A", 10, 2002, 35.0}, + new Object[]{"A", 20, 2008, 15.0}, + new Object[]{"A", 15, 2008, null} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20.0}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testSumWithSamePartitionAndCollationKey() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, DOUBLE, List.of(0), 0, RANGE, Integer.MIN_VALUE, 0, + getSum(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 59.0}, + new Object[]{"A", 10, 2002, 59.0}, + new Object[]{"A", 20, 2008, 59.0}, + new Object[]{"A", 15, 2008, 59.0} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 30.0}, + new Object[]{"B", 20, 2005, 30.0} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testFirstValueWithUnboundedPrecedingLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, Integer.MIN_VALUE, 0, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new 
Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 14}, + new Object[]{"A", 15, 2008, 14} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testFirstValueWithUnboundedPrecedingLowerAndUnboundedFollowingUpper( + WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, Integer.MIN_VALUE, Integer.MAX_VALUE, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 14}, + new Object[]{"A", 15, 2008, 14} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void 
testFirstValueWithUnboundedPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, Integer.MIN_VALUE, -2, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, null}, + new Object[]{"A", 20, 2008, 14}, + new Object[]{"A", 15, 2008, 14} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testFirstValueWithUnboundedPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, Integer.MIN_VALUE, 2, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 14}, + new 
Object[]{"A", 15, 2008, 14} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testFirstValueWithCurrentRowLowerAndUnboundedFollowingUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, 0, Integer.MAX_VALUE, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 10}, + new Object[]{"A", 20, 2008, 20}, + new Object[]{"A", 15, 2008, frameType == ROWS ? 
15 : 20} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); } + @Test(dataProvider = "windowFrameTypes") + public void testFirstValueWithCurrentRowLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, 0, 0, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 10}, + new Object[]{"A", 20, 2008, 20}, + new Object[]{"A", 15, 2008, frameType == ROWS ? 
15 : 20} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testFirstValueWithCurrentRowLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, 0, 2, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 10}, + new Object[]{"A", 20, 2008, 20}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testFirstValueWithOffsetPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, -1, 2, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: 
+ //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 10}, + new Object[]{"A", 15, 2008, 20} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testFirstValueWithOffsetPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, -2, -1, + getFirstValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 14}, + new Object[]{"A", 15, 2008, 10} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testLastValueWithUnboundedPrecedingLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, Integer.MIN_VALUE, 0, + getLastValue(new 
RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 10}, + new Object[]{"A", 20, 2008, frameType == ROWS ? 20 : 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testLastValueWithUnboundedPrecedingLowerAndUnboundedFollowingUpper( + WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, Integer.MIN_VALUE, Integer.MAX_VALUE, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 15}, + new Object[]{"A", 10, 2002, 15}, + new Object[]{"A", 20, 2008, 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), 
"Second block is EOS (done processing)"); + } + + @Test + public void testLastValueWithUnboundedPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, Integer.MIN_VALUE, -2, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, null}, + new Object[]{"A", 20, 2008, 14}, + new Object[]{"A", 15, 2008, 10} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, null} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testLastValueWithUnboundedPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, Integer.MIN_VALUE, 2, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 20}, + new Object[]{"A", 
10, 2002, 15}, + new Object[]{"A", 20, 2008, 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testLastValueWithCurrentRowLowerAndUnboundedFollowingUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, 0, Integer.MAX_VALUE, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 15}, + new Object[]{"A", 10, 2002, 15}, + new Object[]{"A", 20, 2008, 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test(dataProvider = "windowFrameTypes") + public void testLastValueWithCurrentRowLowerAndCurrentRowUpper(WindowNode.WindowFrameType frameType) { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, frameType, 0, 0, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 
2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 14}, + new Object[]{"A", 10, 2002, 10}, + new Object[]{"A", 20, 2008, frameType == ROWS ? 20 : 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 10}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testLastValueWithCurrentRowLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, 0, 2, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 20}, + new Object[]{"A", 10, 2002, 15}, + new Object[]{"A", 20, 2008, 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testLastValueWithOffsetPrecedingLowerAndOffsetFollowingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = 
prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, -1, 2, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, 20}, + new Object[]{"A", 10, 2002, 15}, + new Object[]{"A", 20, 2008, 15}, + new Object[]{"A", 15, 2008, 15} + ), + "B", List.of( + new Object[]{"B", 10, 2000, 20}, + new Object[]{"B", 20, 2005, 20} + ))); + //@formatter:on + assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + @Test + public void testLastValueWithOffsetPrecedingLowerAndOffsetPrecedingUpper() { + // Given: + //@formatter:off + WindowAggregateOperator operator = prepareDataForWindowFunction(new String[]{"name", "value", "year"}, + new ColumnDataType[]{STRING, INT, INT}, INT, List.of(0), 2, ROWS, -2, -1, + getLastValue(new RexExpression.InputRef(1)), + new Object[][]{ + new Object[]{"A", 14, 2000}, + new Object[]{"A", 10, 2002}, + new Object[]{"A", 20, 2008}, + new Object[]{"A", 15, 2008}, + new Object[]{"B", 10, 2000}, + new Object[]{"B", 20, 2005} + }); + //@formatter:on + + // When: + List resultRows = operator.nextBlock().getContainer(); + + // Then: + //@formatter:off + verifyResultRows(resultRows, List.of(0), Map.of( + "A", List.of( + new Object[]{"A", 14, 2000, null}, + new Object[]{"A", 10, 2002, 14}, + new Object[]{"A", 20, 2008, 10}, + new Object[]{"A", 15, 2008, 20} + ), + "B", List.of( + new Object[]{"B", 10, 2000, null}, + new Object[]{"B", 20, 2005, 10} + ))); + //@formatter:on + 
assertTrue(operator.nextBlock().isSuccessfulEndOfStreamBlock(), "Second block is EOS (done processing)"); + } + + private WindowAggregateOperator prepareDataForWindowFunction(String[] inputSchemaCols, + ColumnDataType[] inputSchemaColTypes, ColumnDataType outputType, List partitionKeys, + int collationFieldIndex, WindowNode.WindowFrameType frameType, int windowFrameLowerBound, + int windowFrameUpperBound, RexExpression.FunctionCall functionCall, Object[][] rows) { + DataSchema inputSchema = new DataSchema(inputSchemaCols, inputSchemaColTypes); + when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, rows)) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + + String[] outputSchemaCols = new String[inputSchemaCols.length + 1]; + System.arraycopy(inputSchemaCols, 0, outputSchemaCols, 0, inputSchemaCols.length); + outputSchemaCols[inputSchemaCols.length] = functionCall.getFunctionName().toLowerCase(); + + ColumnDataType[] outputSchemaColTypes = new ColumnDataType[inputSchemaColTypes.length + 1]; + System.arraycopy(inputSchemaColTypes, 0, outputSchemaColTypes, 0, inputSchemaColTypes.length); + outputSchemaColTypes[inputSchemaCols.length] = outputType; + + DataSchema resultSchema = new DataSchema(outputSchemaCols, outputSchemaColTypes); + List aggCalls = List.of(functionCall); + List collations = List.of(new RelFieldCollation(collationFieldIndex)); + return getOperator(inputSchema, resultSchema, partitionKeys, collations, aggCalls, frameType, windowFrameLowerBound, + windowFrameUpperBound); + } + + @Test + public void testShouldThrowOnWindowFrameWithInvalidOffsetBounds() { + // Given: + DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); + when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, "foo"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = + new DataSchema(new 
String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, STRING, DOUBLE}); + List keys = List.of(0); + List aggCalls = List.of(getSum(new RexExpression.InputRef(1))); + + // Then: + IllegalStateException e = Assert.expectThrows(IllegalStateException.class, + () -> getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, ROWS, 5, 2)); + assertEquals(e.getMessage(), "Window frame lower bound can't be greater than upper bound"); + + e = Assert.expectThrows(IllegalStateException.class, + () -> getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, ROWS, -2, -3)); + assertEquals(e.getMessage(), "Window frame lower bound can't be greater than upper bound"); + } + + @Test + public void testShouldThrowOnWindowFrameWithOffsetBoundsForRange() { + // TODO: Remove this test when support for RANGE window frames with offset PRECEDING / FOLLOWING is added + // Given: + DataSchema inputSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, STRING}); + when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inputSchema, new Object[]{2, "foo"})) + .thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0)); + DataSchema resultSchema = + new DataSchema(new String[]{"group", "arg", "sum"}, new ColumnDataType[]{INT, STRING, DOUBLE}); + List keys = List.of(0); + List aggCalls = List.of(getSum(new RexExpression.InputRef(1))); + + // Then: + IllegalStateException e = Assert.expectThrows(IllegalStateException.class, + () -> getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, WindowNode.WindowFrameType.RANGE, 5, + Integer.MAX_VALUE)); + assertEquals(e.getMessage(), "RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + + e = Assert.expectThrows(IllegalStateException.class, + () -> getOperator(inputSchema, resultSchema, keys, List.of(), aggCalls, WindowNode.WindowFrameType.RANGE, + Integer.MAX_VALUE, 5)); + assertEquals(e.getMessage(), "RANGE window frame with offset PRECEDING / FOLLOWING is 
not supported"); + } + private WindowAggregateOperator getOperator(DataSchema inputSchema, DataSchema resultSchema, List keys, List collations, List aggCalls, WindowNode.WindowFrameType windowFrameType, int lowerBound, int upperBound, PlanNode.NodeHint nodeHint) { @@ -666,6 +1952,14 @@ private static RexExpression.FunctionCall getSum(RexExpression arg) { return new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.SUM.name(), List.of(arg)); } + private static RexExpression.FunctionCall getFirstValue(RexExpression arg) { + return new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.FIRST_VALUE.name(), List.of(arg)); + } + + private static RexExpression.FunctionCall getLastValue(RexExpression arg) { + return new RexExpression.FunctionCall(ColumnDataType.INT, SqlKind.LAST_VALUE.name(), List.of(arg)); + } + private static void verifyResultRows(List resultRows, List keys, Map> expectedKeyedRows) { int numKeys = keys.size(); @@ -698,4 +1992,12 @@ private static void verifyResultRows(List resultRows, List e assertEquals(resultRows.get(i), expectedRows.get(i)); } } + + @DataProvider(name = "windowFrameTypes") + public Object[][] getWindowFrameTypes() { + return new Object[][]{ + {ROWS}, + {WindowNode.WindowFrameType.RANGE} + }; + } }