From fae49bc06a9dac705fb266b7883d5f951cb4da57 Mon Sep 17 00:00:00 2001 From: Julian Hyde Date: Wed, 20 Nov 2024 15:44:27 -0800 Subject: [PATCH] [CALCITE-6691] QUALIFY on project references wrong columns The bug was that we were translating the QUALIFY expression after the SELECT clause, and the mapping was wrong if columns had moved around or if the expressions were non-trivial. The fix is to convert the QUALIFY expression at the same time as the HAVING expression. Also add support for a QUALIFY clause in a query with a GROUP BY. In this case, the QUALIFY expression may reference aggregate functions. Previously such queries would give a ClassCastException (trying to convert a LogicalAggregate to a LogicalProject). If a QUALIFY expression references or duplicates an expression in the SELECT clause, we no longer detect and deduplicate that. This has made one or two plans more verbose. Potentially we could add back deduplication, or just let Calc do common subexpression elimination later in the planning process, as it always has done. Also refactor SqlToRelConverter. Previously we were creating some `SqlNodeList` wrappers only to use their `accept(SqlVisitor)` method; now we leave them as `List` and call the new method `SqlVisitor.visitAll(List)`. 
Close apache/calcite#4061 --- .../apache/calcite/sql/util/SqlVisitor.java | 7 + .../calcite/sql2rel/SqlToRelConverter.java | 280 +++++++----------- .../calcite/test/SqlToRelConverterTest.java | 24 ++ .../calcite/test/SqlToRelConverterTest.xml | 55 +++- core/src/test/resources/sql/qualify.iq | 31 ++ 5 files changed, 212 insertions(+), 185 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java b/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java index 0bcd91e9373..635454d1410 100644 --- a/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java +++ b/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java @@ -26,6 +26,8 @@ import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlOperator; +import java.util.List; + /** * Visitor class, follows the * {@link org.apache.calcite.util.Glossary#VISITOR_PATTERN visitor pattern}. @@ -103,4 +105,9 @@ public interface SqlVisitor { default R visitNode(SqlNode n) { return n.accept(this); } + + /** Visits all nodes in a list. 
*/ + default void visitAll(List selectList) { + selectList.forEach(e -> e.accept(this)); + } } diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index 89516750f35..c22d17328c0 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -84,7 +84,6 @@ import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexPatternFieldRef; import org.apache.calcite.rex.RexRangeRef; import org.apache.calcite.rex.RexShuttle; @@ -222,7 +221,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.calcite.linq4j.Nullness.castNonNull; -import static org.apache.calcite.runtime.FlatLists.append; import static org.apache.calcite.sql.SqlUtil.containsDefault; import static org.apache.calcite.sql.SqlUtil.containsIn; import static org.apache.calcite.sql.SqlUtil.stripAs; @@ -233,7 +231,11 @@ import static org.apache.calcite.sql.type.SqlTypeUtil.isExactNumeric; import static org.apache.calcite.sql.type.SqlTypeUtil.keepSourceTypeAndTargetNullability; import static org.apache.calcite.sql.type.SqlTypeUtil.promoteToRowType; +import static org.apache.calcite.sql.validate.SqlValidatorUtil.addAlias; import static org.apache.calcite.util.Static.RESOURCE; +import static org.apache.calcite.util.Util.first; +import static org.apache.calcite.util.Util.last; +import static org.apache.calcite.util.Util.skipLast; import static org.apache.calcite.util.Util.transform; import static java.util.Objects.requireNonNull; @@ -803,20 +805,8 @@ protected void convertSelectImpl( final RelCollation collation = cluster.traitSet().canonize(RelCollations.of(collationList)); - if (validator().isAggregate(select)) { - convertAgg( - bb, - select, - 
orderExprList); - } else { - convertSelectList( - bb, - measureBb, - select, - orderExprList); - } - - convertQualify(bb, select.getQualify()); + convertSelectList(bb, measureBb, select, orderExprList, ImmutableList.of(), + select.getQualify()); if (select.isDistinct()) { distinctify(bb, true); @@ -1174,8 +1164,7 @@ private void replaceSubQueries( } private void substituteSubQuery(Blackboard bb, SubQuery subQuery) { - final RexNode expr = subQuery.expr; - if (expr != null) { + if (subQuery.expr != null) { // Already done. return; } @@ -2933,7 +2922,7 @@ protected void convertCollectionTable( TableExpressionFactory expressionFunction = clazz -> Schemas.getTableExpression( requireNonNull(schema, "schema").plus(), - Util.last(udf.getNameAsId().names), table, clazz); + last(udf.getNameAsId().names), table, clazz); RelOptTable relOptTable = RelOptTableImpl.create(null, rowType, udf.getNameAsId().names, table, expressionFunction); @@ -3496,53 +3485,52 @@ private static JoinRelType convertJoinType(JoinType joinType) { * @param bb Scope within which to resolve identifiers * @param select Query * @param orderExprList Additional expressions needed to implement ORDER BY + * @param extraList Additional expressions needed to implement QUALIFY */ protected void convertAgg(Blackboard bb, SqlSelect select, - List orderExprList) { + List orderExprList, List extraList) { requireNonNull(bb.root, "bb.root"); - SqlNodeList groupList = select.getGroup(); - SqlNodeList selectList = select.getSelectList(); - SqlNode having = select.getHaving(); + List groupList = first(select.getGroup(), ImmutableList.of()); + List selectList = select.getSelectList(); + @Nullable SqlNode having = select.getHaving(); final AggConverter aggConverter = AggConverter.create(bb, (AggregatingSelectScope) validator().getSelectScope(select)); createAggImpl(bb, aggConverter, selectList, groupList, having, - orderExprList); + orderExprList, extraList); } private void createAggImpl(Blackboard bb, final AggConverter 
aggConverter, - SqlNodeList selectList, - @Nullable SqlNodeList groupList, + List selectList, + List groupList, @Nullable SqlNode having, - List orderExprList) { + List orderExprList, + List extraList) { // Find aggregate functions in SELECT and HAVING clause final AggregateFinder aggregateFinder = new AggregateFinder(); - selectList.accept(aggregateFinder); + aggregateFinder.visitAll(selectList); if (having != null) { having.accept(aggregateFinder); } // first replace the sub-queries inside the aggregates // because they will provide input rows to the aggregates. - replaceSubQueries(bb, aggregateFinder.list, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + aggregateFinder.list.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // also replace sub-queries inside filters in the aggregates - replaceSubQueries(bb, aggregateFinder.filterList, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + aggregateFinder.filterList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // also replace sub-queries inside ordering spec in the aggregates - replaceSubQueries(bb, aggregateFinder.orderList, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); - - // If group-by clause is missing, pretend that it has zero elements. 
- if (groupList == null) { - groupList = SqlNodeList.EMPTY; - } + aggregateFinder.orderList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); - replaceSubQueries(bb, groupList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + // also replace sub-queries inside GROUP BY + groupList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // register the group exprs @@ -3567,9 +3555,11 @@ private void createAggImpl(Blackboard bb, // convert the select and having expressions, so that the // agg converter knows which aggregations are required - selectList.accept(aggConverter); - // Assert we don't have dangling items left in the stack - assert !aggConverter.inOver; + for (SqlNode expr : selectList) { + expr.accept(aggConverter); + // Assert we don't have dangling items left in the stack + assert !aggConverter.inOver; + } for (SqlNode expr : orderExprList) { expr.accept(aggConverter); assert !aggConverter.inOver; @@ -3578,6 +3568,10 @@ private void createAggImpl(Blackboard bb, having.accept(aggConverter); assert !aggConverter.inOver; } + for (SqlNode expr : extraList) { + expr.accept(aggConverter); + assert !aggConverter.inOver; + } // compute inputs to the aggregator final PairList preExprs; @@ -3621,10 +3615,9 @@ private void createAggImpl(Blackboard bb, // Tell bb which of group columns are sorted. bb.columnMonotonicities.clear(); - for (SqlNode groupItem : groupList) { - bb.columnMonotonicities.add( - bb.scope.getMonotonicity(groupItem)); - } + groupList.forEach(e -> + bb.columnMonotonicities.add( + bb.scope.getMonotonicity(e))); // Add the aggregator bb.setRoot( @@ -3646,7 +3639,8 @@ private void createAggImpl(Blackboard bb, // This needs to be done separately from the sub-query inside // any aggregate in the select list, and after the aggregate rel // is allocated. 
- replaceSubQueries(bb, selectList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + selectList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // Now sub-queries in the entire select list have been converted. // Convert the select expressions to get the final list to be @@ -3677,6 +3671,10 @@ private void createAggImpl(Blackboard bb, projects.add(bb.convertExpression(expr), SqlValidatorUtil.alias(expr, k++)); } + for (SqlNode expr : extraList) { + projects.add(bb.convertExpression(expr), + SqlValidatorUtil.alias(expr, k++)); + } } finally { bb.agg = null; } @@ -3702,10 +3700,9 @@ private void createAggImpl(Blackboard bb, // Tell bb which of group columns are sorted. bb.columnMonotonicities.clear(); - for (SqlNode selectItem : selectList) { - bb.columnMonotonicities.add( - bb.scope.getMonotonicity(selectItem)); - } + selectList.forEach(e -> + bb.columnMonotonicities.add( + bb.scope.getMonotonicity(e))); } /** @@ -4717,19 +4714,49 @@ private RelNode convertMultisets(final List operands, return ret; } - private void convertSelectList( + /** Converts a select-list for an aggregate or non-aggregate query, + * adding a filter for a QUALIFY clause if present. */ + private void convertSelectList(Blackboard bb, Blackboard measureBb, + SqlSelect select, List orderExprList, List extraList, + @Nullable SqlNode qualify) { + if (qualify != null) { + // Add the QUALIFY expression to the select-list, + // convert the extended list, + // filter by the last field, + // and remove the last field. 
+ convertSelectList(bb, measureBb, select, orderExprList, + ImmutableList.of(addAlias(qualify, "QualifyExpression")), null); + bb.setRoot( + relBuilder.push(bb.root()) + .filter(last(relBuilder.fields()))// filter on last column + .project(skipLast(relBuilder.fields())) // remove last column + .build(), + false); + return; + } + + if (validator().isAggregate(select)) { + convertAgg(bb, select, orderExprList, extraList); + } else { + convertNonAggregateSelectList(bb, measureBb, select, orderExprList, + extraList); + } + } + + /** Converts the select-list of a non-aggregate query. */ + private void convertNonAggregateSelectList( Blackboard bb, Blackboard measureBb, SqlSelect select, - List orderList) { - SqlNodeList selectList = select.getSelectList(); - selectList = validator().expandStar(selectList, select, false); + List orderList, + List extraList) { + final SqlNodeList selectList = + validator().expandStar(select.getSelectList(), select, false); - replaceSubQueries(bb, selectList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); - replaceSubQueries(bb, new SqlNodeList(orderList, SqlParserPos.ZERO), - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + Iterables.concat(selectList, orderList, extraList).forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); - List fieldNames = new ArrayList<>(); + final List fieldNames = new ArrayList<>(); final List exprs = new ArrayList<>(); final Collection aliases = new TreeSet<>(); @@ -4764,7 +4791,7 @@ private void convertSelectList( fieldNames.add(deriveAlias(expr, aliases, i)); } - // Project extra fields for sorting. + // Project extra fields for ORDER BY. for (SqlNode expr : orderList) { ++i; SqlNode expr2 = validator().expandOrderExpr(select, expr); @@ -4772,12 +4799,19 @@ private void convertSelectList( fieldNames.add(deriveAlias(expr, aliases, i)); } - fieldNames = + // Project extra fields for QUALIFY. 
+ for (SqlNode expr : extraList) { + ++i; + exprs.add(bb.convertExpression(expr)); + fieldNames.add(deriveAlias(expr, aliases, i)); + } + + final List uniqueFieldNames = SqlValidatorUtil.uniquify(fieldNames, catalogReader.nameMatcher().isCaseSensitive()); relBuilder.push(bb.root()) - .projectNamed(exprs, fieldNames, true); + .projectNamed(exprs, uniqueFieldNames, true); RelNode project = relBuilder.build(); @@ -4789,7 +4823,8 @@ private void convertSelectList( // in p.r instead of the original exprs Project project1 = (Project) p.r; r = relBuilder.push(bb.root()) - .projectNamed(project1.getProjects(), fieldNames, true, ImmutableSet.of(p.id)) + .projectNamed(project1.getProjects(), uniqueFieldNames, true, + ImmutableSet.of(p.id)) .build(); } else { r = project; @@ -4846,113 +4881,6 @@ private static String deriveAlias( return alias; } - private void convertQualify(Blackboard bb, @Nullable SqlNode qualify) { - if (qualify == null) { - return; - } - - final LogicalProject projectionFromSelect = - requireNonNull((LogicalProject) bb.root, "root"); - - // Convert qualify SqlNode to a RexNode - replaceSubQueries(bb, qualify, RelOptUtil.Logic.UNKNOWN_AS_FALSE); - final RelNode originalRoot = requireNonNull(bb.root, "root"); - RexNode qualifyRexNode; - try { - // Set the root to the input of the project, - // since QUALIFY might have an expression in the OVER clause - // that references a column not in the SELECT. 
- bb.setRoot(projectionFromSelect.getInput(), false); - qualifyRexNode = bb.convertExpression(qualify); - } finally { - bb.setRoot(originalRoot, false); - } - - // Check to see if the qualify expression has a referenced expression and - // do some referencing accordingly - final RexNode qualifyWithReferencesRexNode = - qualifyRexNode.accept( - new DuplicateEliminator(projectionFromSelect.getProjects())); - - // Create a Project with the QUALIFY expression - if (qualifyWithReferencesRexNode.equals(qualifyRexNode)) { - // The QUALIFY expression does not depend on any references like so: - // - // SELECT A, B - // FROM tbl - // QUALIFY WINDOW(C) = 1 - // - // Meaning we should generate a plan like: - // Project(A, B, WINDOW(C) = 1 as QualifyExpression) - // TableScan(tbl) - // - relBuilder.push(projectionFromSelect.getInput()) - .project( - append(projectionFromSelect.getProjects(), qualifyRexNode), - append(projectionFromSelect.getRowType().getFieldNames(), - "QualifyExpression")); - } else { - // The QUALIFY expression depended on a reference meaning - // we need to introduce an extra project like so: - // - // SELECT A, B, WINDOW(C) as window_val - // FROM tbl - // QUALIFY window_val = 1 - // - // Meaning we should generate a plan like: - // - // Project($0, $1, $2, =($2, 1) as QualifyExpression) - // Project(A, B, WINDOW(C) as window_val) - // TableScan(tbl) - // - // This is a very specific application of Common Subexpression Elimination - // (CSE), since the window value pops up twice. 
- relBuilder.push(requireNonNull(bb.root, "root")) - .project( - append(relBuilder.fields(), qualifyWithReferencesRexNode), - append(relBuilder.peek().getRowType().getFieldNames(), - "QualifyExpression")); - } - - // Filter on that extra column - relBuilder.filter(Util.last(relBuilder.fields())); - - // Remove that extra column from the projection - relBuilder.project( - Util.first(relBuilder.fields(), - projectionFromSelect.getProjects().size())); - - // Update the root - bb.setRoot(relBuilder.build(), false); - } - - /** Eliminates a common sub-expression by looking for a {@link RexNode} - * in the expressions of a {@link Project}; if found, returns a refIndex - * instead of the raw node. */ - private static final class DuplicateEliminator extends RexShuttle { - private final List projects; - - DuplicateEliminator(List projects) { - this.projects = projects; - } - - @Override public RexNode visitCall(RexCall call) { - final int i = projects.indexOf(call); - if (i >= 0) { - return new RexInputRef(i, projects.get(i).getType()); - } - return super.visitCall(call); - } - - @Override public RexNode visitOver(RexOver over) { - final int i = projects.indexOf(over); - if (i >= 0) { - return new RexInputRef(i, projects.get(i).getType()); - } - return over; - } - } - /** * Converts a WITH sub-query into a relational expression. */ @@ -6342,10 +6270,10 @@ public static class SqlIdentifierFinder implements SqlVisitor { * Visitor that collects all aggregate functions in a {@link SqlNode} tree. 
*/ private static class AggregateFinder extends SqlBasicVisitor { - final SqlNodeList list = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList filterList = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList distinctList = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList orderList = new SqlNodeList(SqlParserPos.ZERO); + final List list = new ArrayList<>(); + final List filterList = new ArrayList<>(); + final List distinctList = new ArrayList<>(); + final List orderList = new ArrayList<>(); @Override public Void visit(SqlCall call) { // ignore window aggregates and ranking functions (associated with OVER operator) @@ -6368,7 +6296,7 @@ private static class AggregateFinder extends SqlBasicVisitor { final SqlNodeList distinctList = (SqlNodeList) call.getOperandList().get(1); list.add(aggCall); - this.distinctList.addAll(distinctList.getList()); + this.distinctList.addAll(distinctList); return null; } diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index cdec09e5ee1..90ed3b69eb7 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -3640,6 +3640,30 @@ void checkCorrelatedMapSubQuery(boolean expand) { .ok(); } + /** Test case for + * [CALCITE-6691] + * QUALIFY on subquery that projects. 
*/ + @Test void testQualifyOnProject() { + sql("WITH t0 AS (SELECT deptno, sal FROM emp),\n" + + "t1 AS (SELECT deptno\n" + + " FROM t0\n" + + " QUALIFY row_number() OVER (PARTITION BY deptno\n" + + " ORDER BY sal DESC) = 1)\n" + + "SELECT deptno FROM t1") + .ok(); + } + + @Test void testQualifyAfterGroupBy() { + sql("WITH t0 AS (SELECT deptno, sal FROM emp),\n" + + "t1 AS (SELECT deptno, sal, COUNT(*)\n" + + " FROM t0\n" + + " GROUP BY deptno, sal\n" + + " QUALIFY row_number() OVER (PARTITION BY deptno\n" + + " ORDER BY COUNT(*) DESC) = 1)\n" + + "SELECT deptno FROM t1") + .ok(); + } + @Test void testQualifyWithWindowClause() { sql("SELECT empno, ename, SUM(deptno) OVER myWindow as sumDeptNo\n" + "FROM emp\n" diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index 5195579f8f5..98ab6f3fe53 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -6393,6 +6393,27 @@ LogicalProject(EMPNO=[$0], DEPTNO=[$1], DEPTNO0=[$3], NAME=[$4]) LogicalProject(EMPNO=[$0], DEPTNO=[$7], $f2=[+($7, 20)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) +]]> + + + + + + + + @@ -6465,6 +6486,24 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2]) LogicalFilter(condition=[$3]) LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], QualifyExpression=[=(RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC), 1)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + @@ -6513,12 +6552,11 @@ LIMIT 5]]> LogicalSort(sort0=[$2], dir0=[ASC], fetch=[5]) LogicalAggregate(group=[{0, 1, 2, 3}]) LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2], RANK_VAL=[$3]) - LogicalFilter(condition=[$5]) - LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2], RANK_VAL=[$3], EXPR$0=[$4], QualifyExpression=[=($3, $4)]) + 
LogicalFilter(condition=[$4]) + LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], RANK_VAL=[RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC)], QualifyExpression=[=(RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC), $9)]) LogicalJoin(condition=[true], joinType=[left]) - LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], RANK_VAL=[RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC)]) - LogicalFilter(condition=[>($5, 1000)]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalFilter(condition=[>($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) LogicalProject($f0=[0]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) @@ -6568,11 +6606,10 @@ QUALIFY rank_val = (SELECT COUNT(*) FROM emp)]]>