Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Aggregations in Case Expressions #12613

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -747,18 +747,10 @@ private static Expression toExpression(SqlNode node) {
for (int i = 0; i < whenOperands.size(); i++) {
SqlNode whenSqlNode = whenOperands.get(i);
Expression whenExpression = toExpression(whenSqlNode);
if (isAggregateExpression(whenExpression)) {
throw new SqlCompilationException(
"Aggregation functions inside WHEN Clause is not supported - " + whenSqlNode);
}
caseFuncExpr.getFunctionCall().addToOperands(whenExpression);

SqlNode thenSqlNode = thenOperands.get(i);
Expression thenExpression = toExpression(thenSqlNode);
if (isAggregateExpression(thenExpression)) {
throw new SqlCompilationException(
"Aggregation functions inside THEN Clause is not supported - " + thenSqlNode);
}
caseFuncExpr.getFunctionCall().addToOperands(thenExpression);
}
Expression elseExpression = toExpression(elseOperand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,56 @@ public void testCaseWhenStatements() {
Assert.assertEquals(caseFunc.getOperands().get(6).getLiteral().getFieldValue(), 0L);
}

@Test(expectedExceptions = SqlCompilationException.class)
public void testInvalidCaseWhenStatements() {
gortiz marked this conversation as resolved.
Show resolved Hide resolved
// Not support Aggregation functions in case statements.
try {
//@formatter:off
CalciteSqlParser.compileToPinotQuery(
"SELECT OrderID, Quantity,\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails");
@Test
public void testAggregationInCaseWhenStatementsWithGroupBy() {
//@formatter:off
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(
"SELECT OrderID, SUM(Quantity),\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails\n"
+ "GROUP BY OrderID");
//@formatter:on
Function caseStm = pinotQuery.getSelectList().get(2).getFunctionCall().getOperands().get(0).getFunctionCall();
Assert.assertEquals(caseStm.getOperator(), "case");
Expression firstWhen = caseStm.getOperands().get(0);
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall()
.getOperands().get(0).getIdentifier().getName(), "Quantity");

Expression secondWhen = caseStm.getOperands().get(2);
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall()
.getOperands().get(0).getIdentifier().getName(), "Quantity");
}

@Test
public void testAggregationInCaseWhenStatements() {
//@formatter:off
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(
"SELECT sum(Quantity),\n"
+ "CASE\n"
+ " WHEN sum(Quantity) > 30 THEN 'The quantity is greater than 30'\n"
+ " WHEN sum(Quantity) = 30 THEN 'The quantity is 30'\n"
+ " ELSE 'The quantity is under 30'\n"
+ "END AS QuantityText\n"
+ "FROM OrderDetails\n");
//@formatter:on
} catch (SqlCompilationException e) {
Assert.assertEquals(e.getMessage(),
"Aggregation functions inside WHEN Clause is not supported - SUM(`Quantity`) > 30");
throw e;
}

Function caseStm = pinotQuery.getSelectList().get(1).getFunctionCall().getOperands().get(0).getFunctionCall();
Assert.assertEquals(caseStm.getOperator(), "case");
Expression firstWhen = caseStm.getOperands().get(0);
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(firstWhen.getFunctionCall().getOperands().get(0).getFunctionCall()
.getOperands().get(0).getIdentifier().getName(), "Quantity");

Expression secondWhen = caseStm.getOperands().get(2);
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall().getOperator(), "sum");
Assert.assertEquals(secondWhen.getFunctionCall().getOperands().get(0).getFunctionCall()
.getOperands().get(0).getIdentifier().getName(), "Quantity");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,17 @@ private void testHardcodedQueriesCommon()
h2Query = "SELECT DaysSinceEpoch - 25, COUNT(*) FROM mytable " + "GROUP BY DaysSinceEpoch "
+ "ORDER BY COUNT(*), DaysSinceEpoch DESC";
testQuery(query, h2Query);

// Test aggregation functions in a CaseWhen statement
query =
"SELECT AirlineID, "
+ "CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay"
+ " FROM mytable GROUP BY AirlineID";
testQuery(query);
query =
"SELECT CASE WHEN Sum(ArrDelay) < 0 THEN 0 WHEN SUM(ArrDelay) > 0 THEN SUM(ArrDelay) END AS SumArrDelay"
+ " FROM mytable";
testQuery(query);
}

private void testHardcodedQueriesV2()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ public Object[][] provideTestSql() {
+ " SUM(CASE WHEN a.col2 <> 'foo' AND a.col2 <> 'alice' THEN 1 ELSE 0 END) as unmatch_sum "
+ " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"},

new Object[]{"SELECT a.col1, CASE WHEN sum(a.col3) = 0 THEN 0 ELSE SUM(a.col3) END AS match_sum "
+ " FROM a WHERE a.ts >= 1600000000 GROUP BY a.col1"},
Comment on lines +109 to +110
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would like to add additional unit tests for V2 as I think this test just verifies Semantics and Syntax, would like to test execution results.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we want to add tests for V1. It is already working in V2


new Object[]{"SELECT a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 "
+ " WHERE a.col3 IN (1, 2, 3) OR (a.col3 > 10 AND a.col3 < 50)"},

Expand Down
Loading