Skip to content

Commit

Permalink
Allow Aggregations in Case Expressions (#12613)
Browse files Browse the repository at this point in the history
  • Loading branch information
ege-st authored Mar 12, 2024
1 parent da2604b commit e7220fa
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -747,18 +747,10 @@ private static Expression toExpression(SqlNode node) {
for (int i = 0; i < whenOperands.size(); i++) {
SqlNode whenSqlNode = whenOperands.get(i);
Expression whenExpression = toExpression(whenSqlNode);
if (isAggregateExpression(whenExpression)) {
throw new SqlCompilationException(
"Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode);
}
caseFuncExpr.getFunctionCall().addToOperands(whenExpression);

SqlNode thenSqlNode = thenOperands.get(i);
Expression thenExpression = toExpression(thenSqlNode);
if (isAggregateExpression(thenExpression)) {
throw new SqlCompilationException(
"Aggregation functions inside THEN Clause is not supported - " + thenSqlNode);
}
caseFuncExpr.getFunctionCall().addToOperands(thenExpression);
}
Expression elseExpression = toExpression(elseOperand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,57 @@ public void testCaseWhenStatements() {
Assert.assertEquals(caseFunc.getOperands().get(6).getLiteral().getFieldValue(), 0L);
}

@Test(expectedExceptions = SqlCompilationException.class)
public void testInvalidCaseWhenStatements() {
// Not support Aggregation functions in case statements.
try {
//@formatter:off
CalciteSqlParser.compileToPinotQuery(
"SELECT OrderID, Quantity,\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails");
//@formatter:on
} catch (SqlCompilationException e) {
Assert.assertEquals(e.getMessage(),
"Aggregation functions inside WHEN Clause is not supported - SUM(`Quantity`) > 30");
throw e;
}
@Test
public void testAggregationInCaseWhenStatementsWithGroupBy() {
//@formatter:off
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(
"SELECT OrderID, SUM(Quantity),\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails\n"
+ "GROUP BY OrderID");
//@formatter:on
Function caseStm = pinotQuery.getSelectList().get(2).getFunctionCall().getOperands().get(0).getFunctionCall();
Assert.assertEquals(caseStm.getOperator(), "case");
Expression firstWhen = caseStm.getOperands().get(0);
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(
firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier()
.getName(), "Quantity");
Expression secondWhen = caseStm.getOperands().get(2);
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(
secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier()
.getName(), "Quantity");
}

@Test
public void testAggregationInCaseWhenStatements() {
//@formatter:off
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(
"SELECT sum(Quantity),\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails\n");
//@formatter:on
Function caseStm = pinotQuery.getSelectList().get(1).getFunctionCall().getOperands().get(0).getFunctionCall();
Assert.assertEquals(caseStm.getOperator(), "case");
Expression firstWhen = caseStm.getOperands().get(0);
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(
firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier()
.getName(), "Quantity");
Expression secondWhen = caseStm.getOperands().get(2);
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(
secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0).getIdentifier()
.getName(), "Quantity");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ private void testHardcodedQueriesCommon()
h2Query = "SELECT DaysSinceEpoch - 25, COUNT(*) FROM mytable " + "GROUP BY DaysSinceEpoch "
+ "ORDER BY COUNT(*), DaysSinceEpoch DESC";
testQuery(query, h2Query);

// Test aggregation functions in a CaseWhen statement
query = "SELECT AirlineID, "
+ "CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay "
+ "FROM mytable GROUP BY AirlineID";
testQuery(query);
query = "SELECT CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay "
+ "FROM mytable";
testQuery(query);
}

private void testHardcodedQueriesV2()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ public Object[][] provideTestSql() {
+ " SUM(CASE WHEN a.col2 <> 'foo' AND a.col2 <> 'alice' THEN 1 ELSE 0 END) as unmatch_sum "
+ " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"},

new Object[]{"SELECT a.col1, CASE WHEN sum(a.col3) = 0 THEN 0 ELSE SUM(a.col3) END AS match_sum "
+ " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"},

new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+ " WHERE a.col3 IN (1, 2, 3) OR (a.col3 > 10 AND a.col3 < 50)"},

Expand Down

0 comments on commit e7220fa

Please sign in to comment.