From e7220fa9b585ad71723df1b6688135b48a7df15d Mon Sep 17 00:00:00 2001 From: Erich <134291879+ege-st@users.noreply.github.com> Date: Tue, 12 Mar 2024 19:53:37 -0400 Subject: [PATCH] Allow Aggregations in Case Expressions (#12613) --- .../pinot/sql/parsers/CalciteSqlParser.java | 8 --- .../sql/parsers/CalciteSqlCompilerTest.java | 70 ++++++++++++++----- .../tests/BaseClusterIntegrationTestSet.java | 9 +++ .../org/apache/pinot/query/QueryTestSet.java | 3 + 4 files changed, 63 insertions(+), 27 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java index 3d216bd6436..95ed6af5437 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java @@ -747,18 +747,10 @@ private static Expression toExpression(SqlNode node) { for (int i = 0; i < whenOperands.size(); i++) { SqlNode whenSqlNode = whenOperands.get(i); Expression whenExpression = toExpression(whenSqlNode); - if (isAggregateExpression(whenExpression)) { - throw new SqlCompilationException( - "Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode); - } caseFuncExpr.getFunctionCall().addToOperands(whenExpression); SqlNode thenSqlNode = thenOperands.get(i); Expression thenExpression = toExpression(thenSqlNode); - if (isAggregateExpression(thenExpression)) { - throw new SqlCompilationException( - "Aggregation functions inside THEN Clause is not supported - " + thenSqlNode); - } caseFuncExpr.getFunctionCall().addToOperands(thenExpression); } Expression elseExpression = toExpression(elseOperand); diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java index 02f110603f5..ec039c76134 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java @@ -141,25 +141,57 @@ public void testCaseWhenStatements() { Assert.assertEquals(caseFunc.getOperands().get(6).getLiteral().getFieldValue(), 0L); } - @Test(expectedExceptions = SqlCompilationException.class) - public void testInvalidCaseWhenStatements() { - // Not support Aggregation functions in case statements. - try { - //@formatter:off - CalciteSqlParser.compileToPinotQuery( - "SELECT OrderID, Quantity,\n" - + "CASE\n" - + " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n" - + " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n" - + " ELSE 'The quantity is under 30'\n" - + "END AS QuantityText\n" - + "FROM OrderDetails"); - //@formatter:on - } catch (SqlCompilationException e) { - Assert.assertEquals(e.getMessage(), - "Aggregation functions inside WHEN Clause is not supported - SUM(`Quantity`) > 30"); - throw e; - } + @Test + public void testAggregationInCaseWhenStatementsWithGroupBy() { + //@formatter:off + PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery( + "SELECT OrderID, SUM(Quantity),\n" + + "CASE\n" + + " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n" + + " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n" + + " ELSE 'The quantity is under 30'\n" + + "END AS QuantityText\n" + + "FROM OrderDetails\n" + + "GROUP BY OrderID"); + //@formatter:on + Function caseStm = pinotQuery.getSelectList().get(2).getFunctionCall().getOperands().get(0).getFunctionCall(); + Assert.assertEquals(caseStm.getOperator(), "case"); + Expression firstWhen = caseStm.getOperands().get(0); + Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum"); + Assert.assertEquals( + firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier() + .getName(), "Quantity"); + Expression secondWhen = caseStm.getOperands().get(2); + Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum"); + Assert.assertEquals( + secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier() + .getName(), "Quantity"); + } + + @Test + public void testAggregationInCaseWhenStatements() { + //@formatter:off + PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery( + "SELECT sum(Quantity),\n" + + "CASE\n" + + " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n" + + " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n" + + " ELSE 'The quantity is under 30'\n" + + "END AS QuantityText\n" + + "FROM OrderDetails\n"); + //@formatter:on + Function caseStm = pinotQuery.getSelectList().get(1).getFunctionCall().getOperands().get(0).getFunctionCall(); + Assert.assertEquals(caseStm.getOperator(), "case"); + Expression firstWhen = caseStm.getOperands().get(0); + Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum"); + Assert.assertEquals( + firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier() + .getName(), "Quantity"); + Expression secondWhen = caseStm.getOperands().get(2); + Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum"); + Assert.assertEquals( + secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier() + .getName(), "Quantity"); } @Test diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java index cbea94920d8..e7c5af0dab9 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java @@ -273,6 +273,15 @@ private void testHardcodedQueriesCommon() h2Query = "SELECT DaysSinceEpoch - 25, COUNT(*) FROM mytable " + "GROUP BY DaysSinceEpoch " + "ORDER BY COUNT(*), DaysSinceEpoch DESC"; testQuery(query, h2Query); + + // Test aggregation functions in a CaseWhen statement + query = "SELECT AirlineID, " + + "CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay " + + "FROM mytable GROUP BY AirlineID"; + testQuery(query); + query = "SELECT CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay " + + "FROM mytable"; + testQuery(query); } private void testHardcodedQueriesV2() diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java index 9f59aa37d85..fb86ab53f7e 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/QueryTestSet.java @@ -106,6 +106,9 @@ public Object[][] provideTestSql() { + " SUM(CASE WHEN a.col2 <> 'foo' AND a.col2 <> 'alice' THEN 1 ELSE 0 END) as unmatch_sum " + " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"}, + new Object[]{"SELECT a.col1, CASE WHEN sum(a.col3) = 0 THEN 0 ELSE SUM(a.col3) END AS match_sum " + + " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"}, + new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 " + " WHERE a.col3 IN (1, 2, 3) OR (a.col3 > 10 AND a.col3 < 50)"},