diff --git a/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java b/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java index 0bcd91e9373..635454d1410 100644 --- a/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java +++ b/core/src/main/java/org/apache/calcite/sql/util/SqlVisitor.java @@ -26,6 +26,8 @@ import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlOperator; +import java.util.List; + /** * Visitor class, follows the * {@link org.apache.calcite.util.Glossary#VISITOR_PATTERN visitor pattern}. @@ -103,4 +105,9 @@ public interface SqlVisitor { default R visitNode(SqlNode n) { return n.accept(this); } + + /** Visits all nodes in a list. */ + default void visitAll(List selectList) { + selectList.forEach(e -> e.accept(this)); + } } diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index 89516750f35..c22d17328c0 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -84,7 +84,6 @@ import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexPatternFieldRef; import org.apache.calcite.rex.RexRangeRef; import org.apache.calcite.rex.RexShuttle; @@ -222,7 +221,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.calcite.linq4j.Nullness.castNonNull; -import static org.apache.calcite.runtime.FlatLists.append; import static org.apache.calcite.sql.SqlUtil.containsDefault; import static org.apache.calcite.sql.SqlUtil.containsIn; import static org.apache.calcite.sql.SqlUtil.stripAs; @@ -233,7 +231,11 @@ import static org.apache.calcite.sql.type.SqlTypeUtil.isExactNumeric; import static org.apache.calcite.sql.type.SqlTypeUtil.keepSourceTypeAndTargetNullability; import static org.apache.calcite.sql.type.SqlTypeUtil.promoteToRowType; +import static org.apache.calcite.sql.validate.SqlValidatorUtil.addAlias; import static org.apache.calcite.util.Static.RESOURCE; +import static org.apache.calcite.util.Util.first; +import static org.apache.calcite.util.Util.last; +import static org.apache.calcite.util.Util.skipLast; import static org.apache.calcite.util.Util.transform; import static java.util.Objects.requireNonNull; @@ -803,20 +805,8 @@ protected void convertSelectImpl( final RelCollation collation = cluster.traitSet().canonize(RelCollations.of(collationList)); - if (validator().isAggregate(select)) { - convertAgg( - bb, - select, - orderExprList); - } else { - convertSelectList( - bb, - measureBb, - select, - orderExprList); - } - - convertQualify(bb, select.getQualify()); + convertSelectList(bb, measureBb, select, orderExprList, ImmutableList.of(), + select.getQualify()); if (select.isDistinct()) { distinctify(bb, true); @@ -1174,8 +1164,7 @@ private void replaceSubQueries( } private void substituteSubQuery(Blackboard bb, SubQuery subQuery) { - final RexNode expr = subQuery.expr; - if (expr != null) { + if (subQuery.expr != null) { // Already done. return; } @@ -2933,7 +2922,7 @@ protected void convertCollectionTable( TableExpressionFactory expressionFunction = clazz -> Schemas.getTableExpression( requireNonNull(schema, "schema").plus(), - Util.last(udf.getNameAsId().names), table, clazz); + last(udf.getNameAsId().names), table, clazz); RelOptTable relOptTable = RelOptTableImpl.create(null, rowType, udf.getNameAsId().names, table, expressionFunction); @@ -3496,53 +3485,52 @@ private static JoinRelType convertJoinType(JoinType joinType) { * @param bb Scope within which to resolve identifiers * @param select Query * @param orderExprList Additional expressions needed to implement ORDER BY + * @param extraList Additional expressions needed to implement QUALIFY */ protected void convertAgg(Blackboard bb, SqlSelect select, - List orderExprList) { + List orderExprList, List extraList) { requireNonNull(bb.root, "bb.root"); - SqlNodeList groupList = select.getGroup(); - SqlNodeList selectList = select.getSelectList(); - SqlNode having = select.getHaving(); + List groupList = first(select.getGroup(), ImmutableList.of()); + List selectList = select.getSelectList(); + @Nullable SqlNode having = select.getHaving(); final AggConverter aggConverter = AggConverter.create(bb, (AggregatingSelectScope) validator().getSelectScope(select)); createAggImpl(bb, aggConverter, selectList, groupList, having, - orderExprList); + orderExprList, extraList); } private void createAggImpl(Blackboard bb, final AggConverter aggConverter, - SqlNodeList selectList, - @Nullable SqlNodeList groupList, + List selectList, + List groupList, @Nullable SqlNode having, - List orderExprList) { + List orderExprList, + List extraList) { // Find aggregate functions in SELECT and HAVING clause final AggregateFinder aggregateFinder = new AggregateFinder(); - selectList.accept(aggregateFinder); + aggregateFinder.visitAll(selectList); if (having != null) { having.accept(aggregateFinder); } // first replace the sub-queries inside the aggregates // because they will provide input rows to the aggregates. - replaceSubQueries(bb, aggregateFinder.list, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + aggregateFinder.list.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // also replace sub-queries inside filters in the aggregates - replaceSubQueries(bb, aggregateFinder.filterList, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + aggregateFinder.filterList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // also replace sub-queries inside ordering spec in the aggregates - replaceSubQueries(bb, aggregateFinder.orderList, - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); - - // If group-by clause is missing, pretend that it has zero elements. - if (groupList == null) { - groupList = SqlNodeList.EMPTY; - } + aggregateFinder.orderList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); - replaceSubQueries(bb, groupList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + // also replace sub-queries inside GROUP BY + groupList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // register the group exprs @@ -3567,9 +3555,11 @@ private void createAggImpl(Blackboard bb, // convert the select and having expressions, so that the // agg converter knows which aggregations are required - selectList.accept(aggConverter); - // Assert we don't have dangling items left in the stack - assert !aggConverter.inOver; + for (SqlNode expr : selectList) { + expr.accept(aggConverter); + // Assert we don't have dangling items left in the stack + assert !aggConverter.inOver; + } for (SqlNode expr : orderExprList) { expr.accept(aggConverter); assert !aggConverter.inOver; @@ -3578,6 +3568,10 @@ private void createAggImpl(Blackboard bb, having.accept(aggConverter); assert !aggConverter.inOver; } + for (SqlNode expr : extraList) { + expr.accept(aggConverter); + assert !aggConverter.inOver; + } // compute inputs to the aggregator final PairList preExprs; @@ -3621,10 +3615,9 @@ private void createAggImpl(Blackboard bb, // Tell bb which of group columns are sorted. bb.columnMonotonicities.clear(); - for (SqlNode groupItem : groupList) { - bb.columnMonotonicities.add( - bb.scope.getMonotonicity(groupItem)); - } + groupList.forEach(e -> + bb.columnMonotonicities.add( + bb.scope.getMonotonicity(e))); // Add the aggregator bb.setRoot( @@ -3646,7 +3639,8 @@ private void createAggImpl(Blackboard bb, // This needs to be done separately from the sub-query inside // any aggregate in the select list, and after the aggregate rel // is allocated. - replaceSubQueries(bb, selectList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + selectList.forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); // Now sub-queries in the entire select list have been converted. // Convert the select expressions to get the final list to be @@ -3677,6 +3671,10 @@ private void createAggImpl(Blackboard bb, projects.add(bb.convertExpression(expr), SqlValidatorUtil.alias(expr, k++)); } + for (SqlNode expr : extraList) { + projects.add(bb.convertExpression(expr), + SqlValidatorUtil.alias(expr, k++)); + } } finally { bb.agg = null; } @@ -3702,10 +3700,9 @@ private void createAggImpl(Blackboard bb, // Tell bb which of group columns are sorted. bb.columnMonotonicities.clear(); - for (SqlNode selectItem : selectList) { - bb.columnMonotonicities.add( - bb.scope.getMonotonicity(selectItem)); - } + selectList.forEach(e -> + bb.columnMonotonicities.add( + bb.scope.getMonotonicity(e))); } /** @@ -4717,19 +4714,49 @@ private RelNode convertMultisets(final List operands, return ret; } - private void convertSelectList( + /** Converts a select-list for an aggregate or non-aggregate query, + * adding a filter for a QUALIFY clause if present. */ + private void convertSelectList(Blackboard bb, Blackboard measureBb, + SqlSelect select, List orderExprList, List extraList, + @Nullable SqlNode qualify) { + if (qualify != null) { + // Add the QUALIFY expression to the select-list, + // convert the extended list, + // filter by the last field, + // and remove the last field. + convertSelectList(bb, measureBb, select, orderExprList, + ImmutableList.of(addAlias(qualify, "QualifyExpression")), null); + bb.setRoot( + relBuilder.push(bb.root()) + .filter(last(relBuilder.fields()))// filter on last column + .project(skipLast(relBuilder.fields())) // remove last column + .build(), + false); + return; + } + + if (validator().isAggregate(select)) { + convertAgg(bb, select, orderExprList, extraList); + } else { + convertNonAggregateSelectList(bb, measureBb, select, orderExprList, + extraList); + } + } + + /** Converts the select-list of a non-aggregate query. */ + private void convertNonAggregateSelectList( Blackboard bb, Blackboard measureBb, SqlSelect select, - List orderList) { - SqlNodeList selectList = select.getSelectList(); - selectList = validator().expandStar(selectList, select, false); + List orderList, + List extraList) { + final SqlNodeList selectList = + validator().expandStar(select.getSelectList(), select, false); - replaceSubQueries(bb, selectList, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); - replaceSubQueries(bb, new SqlNodeList(orderList, SqlParserPos.ZERO), - RelOptUtil.Logic.TRUE_FALSE_UNKNOWN); + Iterables.concat(selectList, orderList, extraList).forEach(e -> + replaceSubQueries(bb, e, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN)); - List fieldNames = new ArrayList<>(); + final List fieldNames = new ArrayList<>(); final List exprs = new ArrayList<>(); final Collection aliases = new TreeSet<>(); @@ -4764,7 +4791,7 @@ private void convertSelectList( fieldNames.add(deriveAlias(expr, aliases, i)); } - // Project extra fields for sorting. + // Project extra fields for ORDER BY. for (SqlNode expr : orderList) { ++i; SqlNode expr2 = validator().expandOrderExpr(select, expr); @@ -4772,12 +4799,19 @@ private void convertSelectList( fieldNames.add(deriveAlias(expr, aliases, i)); } - fieldNames = + // Project extra fields for QUALIFY. + for (SqlNode expr : extraList) { + ++i; + exprs.add(bb.convertExpression(expr)); + fieldNames.add(deriveAlias(expr, aliases, i)); + } + + final List uniqueFieldNames = SqlValidatorUtil.uniquify(fieldNames, catalogReader.nameMatcher().isCaseSensitive()); relBuilder.push(bb.root()) - .projectNamed(exprs, fieldNames, true); + .projectNamed(exprs, uniqueFieldNames, true); RelNode project = relBuilder.build(); @@ -4789,7 +4823,8 @@ private void convertSelectList( // in p.r instead of the original exprs Project project1 = (Project) p.r; r = relBuilder.push(bb.root()) - .projectNamed(project1.getProjects(), fieldNames, true, ImmutableSet.of(p.id)) + .projectNamed(project1.getProjects(), uniqueFieldNames, true, + ImmutableSet.of(p.id)) .build(); } else { r = project; @@ -4846,113 +4881,6 @@ private static String deriveAlias( return alias; } - private void convertQualify(Blackboard bb, @Nullable SqlNode qualify) { - if (qualify == null) { - return; - } - - final LogicalProject projectionFromSelect = - requireNonNull((LogicalProject) bb.root, "root"); - - // Convert qualify SqlNode to a RexNode - replaceSubQueries(bb, qualify, RelOptUtil.Logic.UNKNOWN_AS_FALSE); - final RelNode originalRoot = requireNonNull(bb.root, "root"); - RexNode qualifyRexNode; - try { - // Set the root to the input of the project, - // since QUALIFY might have an expression in the OVER clause - // that references a column not in the SELECT. - bb.setRoot(projectionFromSelect.getInput(), false); - qualifyRexNode = bb.convertExpression(qualify); - } finally { - bb.setRoot(originalRoot, false); - } - - // Check to see if the qualify expression has a referenced expression and - // do some referencing accordingly - final RexNode qualifyWithReferencesRexNode = - qualifyRexNode.accept( - new DuplicateEliminator(projectionFromSelect.getProjects())); - - // Create a Project with the QUALIFY expression - if (qualifyWithReferencesRexNode.equals(qualifyRexNode)) { - // The QUALIFY expression does not depend on any references like so: - // - // SELECT A, B - // FROM tbl - // QUALIFY WINDOW(C) = 1 - // - // Meaning we should generate a plan like: - // Project(A, B, WINDOW(C) = 1 as QualifyExpression) - // TableScan(tbl) - // - relBuilder.push(projectionFromSelect.getInput()) - .project( - append(projectionFromSelect.getProjects(), qualifyRexNode), - append(projectionFromSelect.getRowType().getFieldNames(), - "QualifyExpression")); - } else { - // The QUALIFY expression depended on a reference meaning - // we need to introduce an extra project like so: - // - // SELECT A, B, WINDOW(C) as window_val - // FROM tbl - // QUALIFY window_val = 1 - // - // Meaning we should generate a plan like: - // - // Project($0, $1, $2, =($2, 1) as QualifyExpression) - // Project(A, B, WINDOW(C) as window_val) - // TableScan(tbl) - // - // This is a very specific application of Common Subexpression Elimination - // (CSE), since the window value pops up twice. - relBuilder.push(requireNonNull(bb.root, "root")) - .project( - append(relBuilder.fields(), qualifyWithReferencesRexNode), - append(relBuilder.peek().getRowType().getFieldNames(), - "QualifyExpression")); - } - - // Filter on that extra column - relBuilder.filter(Util.last(relBuilder.fields())); - - // Remove that extra column from the projection - relBuilder.project( - Util.first(relBuilder.fields(), - projectionFromSelect.getProjects().size())); - - // Update the root - bb.setRoot(relBuilder.build(), false); - } - - /** Eliminates a common sub-expression by looking for a {@link RexNode} - * in the expressions of a {@link Project}; if found, returns a refIndex - * instead of the raw node. */ - private static final class DuplicateEliminator extends RexShuttle { - private final List projects; - - DuplicateEliminator(List projects) { - this.projects = projects; - } - - @Override public RexNode visitCall(RexCall call) { - final int i = projects.indexOf(call); - if (i >= 0) { - return new RexInputRef(i, projects.get(i).getType()); - } - return super.visitCall(call); - } - - @Override public RexNode visitOver(RexOver over) { - final int i = projects.indexOf(over); - if (i >= 0) { - return new RexInputRef(i, projects.get(i).getType()); - } - return over; - } - } - /** * Converts a WITH sub-query into a relational expression. */ @@ -6342,10 +6270,10 @@ public static class SqlIdentifierFinder implements SqlVisitor { * Visitor that collects all aggregate functions in a {@link SqlNode} tree. */ private static class AggregateFinder extends SqlBasicVisitor { - final SqlNodeList list = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList filterList = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList distinctList = new SqlNodeList(SqlParserPos.ZERO); - final SqlNodeList orderList = new SqlNodeList(SqlParserPos.ZERO); + final List list = new ArrayList<>(); + final List filterList = new ArrayList<>(); + final List distinctList = new ArrayList<>(); + final List orderList = new ArrayList<>(); @Override public Void visit(SqlCall call) { // ignore window aggregates and ranking functions (associated with OVER operator) @@ -6368,7 +6296,7 @@ private static class AggregateFinder extends SqlBasicVisitor { final SqlNodeList distinctList = (SqlNodeList) call.getOperandList().get(1); list.add(aggCall); - this.distinctList.addAll(distinctList.getList()); + this.distinctList.addAll(distinctList); return null; } diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index cdec09e5ee1..90ed3b69eb7 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -3640,6 +3640,30 @@ void checkCorrelatedMapSubQuery(boolean expand) { .ok(); } + /** Test case for + * [CALCITE-6691] + * QUALIFY on subquery that projects. */ + @Test void testQualifyOnProject() { + sql("WITH t0 AS (SELECT deptno, sal FROM emp),\n" + + "t1 AS (SELECT deptno\n" + + " FROM t0\n" + + " QUALIFY row_number() OVER (PARTITION BY deptno\n" + + " ORDER BY sal DESC) = 1)\n" + + "SELECT deptno FROM t1") + .ok(); + } + + @Test void testQualifyAfterGroupBy() { + sql("WITH t0 AS (SELECT deptno, sal FROM emp),\n" + + "t1 AS (SELECT deptno, sal, COUNT(*)\n" + + " FROM t0\n" + + " GROUP BY deptno, sal\n" + + " QUALIFY row_number() OVER (PARTITION BY deptno\n" + + " ORDER BY COUNT(*) DESC) = 1)\n" + + "SELECT deptno FROM t1") + .ok(); + } + @Test void testQualifyWithWindowClause() { sql("SELECT empno, ename, SUM(deptno) OVER myWindow as sumDeptNo\n" + "FROM emp\n" diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index 5195579f8f5..98ab6f3fe53 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -6393,6 +6393,27 @@ LogicalProject(EMPNO=[$0], DEPTNO=[$1], DEPTNO0=[$3], NAME=[$4]) LogicalProject(EMPNO=[$0], DEPTNO=[$7], $f2=[+($7, 20)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) +]]> + + + + + + + + @@ -6465,6 +6486,24 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2]) LogicalFilter(condition=[$3]) LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], QualifyExpression=[=(RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC), 1)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + @@ -6513,12 +6552,11 @@ LIMIT 5]]> LogicalSort(sort0=[$2], dir0=[ASC], fetch=[5]) LogicalAggregate(group=[{0, 1, 2, 3}]) LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2], RANK_VAL=[$3]) - LogicalFilter(condition=[$5]) - LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$2], RANK_VAL=[$3], EXPR$0=[$4], QualifyExpression=[=($3, $4)]) + LogicalFilter(condition=[$4]) + LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], RANK_VAL=[RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC)], QualifyExpression=[=(RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC), $9)]) LogicalJoin(condition=[true], joinType=[left]) - LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], RANK_VAL=[RANK() OVER (PARTITION BY $1 ORDER BY $7 DESC)]) - LogicalFilter(condition=[>($5, 1000)]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalFilter(condition=[>($5, 1000)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) LogicalProject($f0=[0]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) @@ -6568,11 +6606,10 @@ QUALIFY rank_val = (SELECT COUNT(*) FROM emp)]]>