diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index b4430d33087e91..135f80111e980c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; +import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant; import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.MedianConvert; import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc; @@ -47,6 +48,7 @@ public class ExpressionNormalization extends ExpressionRewrite { SupportJavaDateFormatter.INSTANCE, NormalizeBinaryPredicatesRule.INSTANCE, InPredicateDedup.INSTANCE, + InPredicateExtractNonConstant.INSTANCE, InPredicateToEqualToRule.INSTANCE, SimplifyNotExprRule.INSTANCE, SimplifyArithmeticRule.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index bc12c0459eefb5..7f83ab8a090fd8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -35,6 +35,7 @@ public enum ExpressionRuleType { FOLD_CONSTANT_ON_BE, FOLD_CONSTANT_ON_FE, IN_PREDICATE_DEDUP, + IN_PREDICATE_EXTRACT_NON_CONSTANT, IN_PREDICATE_TO_EQUAL_TO, LIKE_TO_EQUAL, MERGE_DATE_TRUNC, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java index aaa822ac691eb7..1be5971f6a28ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java @@ -36,29 +36,29 @@ public class InPredicateDedup implements ExpressionPatternRuleFactory { public static final InPredicateDedup INSTANCE = new InPredicateDedup(); + // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. + // It takes a long time to apply this rule. So set a threshold for the max number. + public static final int REWRITE_OPTIONS_MAX_SIZE = 200; + @Override public List> buildRules() { return ImmutableList.of( - matchesType(InPredicate.class).then(InPredicateDedup::dedup) + matchesType(InPredicate.class) + .when(inPredicate -> inPredicate.getOptions().size() <= REWRITE_OPTIONS_MAX_SIZE) + .then(InPredicateDedup::dedup) .toRule(ExpressionRuleType.IN_PREDICATE_DEDUP) ); } /** dedup */ public static Expression dedup(InPredicate inPredicate) { - // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. - // It takes a long time to apply this rule. So set a threshold for the max number. - int optionSize = inPredicate.getOptions().size(); - if (optionSize > 200) { - return inPredicate; - } ImmutableSet.Builder newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity()); for (Expression option : inPredicate.getOptions()) { newOptionsBuilder.add(option); } Set newOptions = newOptionsBuilder.build(); - if (newOptions.size() == optionSize) { + if (newOptions.size() == inPredicate.getOptions().size()) { return inPredicate; } return new InPredicate(inPredicate.getCompareExpr(), newOptions); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java new file mode 100644 index 00000000000000..c869dafa0a2d8a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.apache.hadoop.util.Lists; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Extract non-constant of InPredicate, For example: + * where k1 in (k2, k3, 10, 20, 30) ==> where k1 in (10, 20, 30) or k1 = k2 or k1 = k3. + * It's because backend handle in predicate which contains none-constant column will reduce performance. + */ +public class InPredicateExtractNonConstant implements ExpressionPatternRuleFactory { + public static final InPredicateExtractNonConstant INSTANCE = new InPredicateExtractNonConstant(); + + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class) + .when(inPredicate -> inPredicate.getOptions().size() + <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE) + .then(this::rewrite) + .toRule(ExpressionRuleType.IN_PREDICATE_EXTRACT_NON_CONSTANT) + ); + } + + private Expression rewrite(InPredicate inPredicate) { + Set nonConstants = Sets.newLinkedHashSetWithExpectedSize(inPredicate.arity()); + for (Expression option : inPredicate.getOptions()) { + if (!option.isConstant()) { + nonConstants.add(option); + } + } + if (nonConstants.isEmpty()) { + return inPredicate; + } + Expression key = inPredicate.getCompareExpr(); + List disjunctions = Lists.newArrayListWithExpectedSize(inPredicate.getOptions().size()); + List constants = inPredicate.getOptions().stream().filter(Expression::isConstant) + .collect(Collectors.toList()); + if (!constants.isEmpty()) { + disjunctions.add(ExpressionUtils.toInPredicateOrEqualTo(key, constants)); + } + for (Expression option : nonConstants) { + disjunctions.add(new EqualTo(key, option)); + } + return ExpressionUtils.or(disjunctions); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index 136b40af5847c9..16006c02080be4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -31,11 +31,9 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; @@ -69,8 +67,6 @@ public enum Mode { public static final OrToIn EXTRACT_MODE_INSTANCE = new OrToIn(Mode.extractMode); public static final OrToIn REPLACE_MODE_INSTANCE = new OrToIn(Mode.replaceMode); - public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2; - private final Mode mode; public OrToIn(Mode mode) { @@ -196,18 +192,9 @@ private Map> mergeCandidates( } private Expression candidatesToFinalResult(Map> candidates) { - List conjuncts = new ArrayList<>(); - for (Expression key : candidates.keySet()) { - Set literals = candidates.get(key); - if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - for (Literal literal : literals) { - conjuncts.add(new EqualTo(key, literal)); - } - } else { - conjuncts.add(new InPredicate(key, ImmutableList.copyOf(literals))); - } - } - return ExpressionUtils.and(conjuncts); + return ExpressionUtils.and(candidates.entrySet().stream() + .map(entry -> ExpressionUtils.toInPredicateOrEqualTo(entry.getKey(), entry.getValue())) + .collect(Collectors.toList())); } /* diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index c78ec7a75fbad1..7c23ce36a3dbf4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -110,7 +110,8 @@ public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) @Override public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { // only handle `NumericType` and `DateLikeType` - if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions()) + if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE + && ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions()) && (ExpressionUtils.matchNumericType(inPredicate.getOptions()) || ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) { return ValueDesc.discrete(context, inPredicate); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 64891882f7d661..6ac69f1eb56375 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -27,14 +27,11 @@ import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; -import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; @@ -45,9 +42,7 @@ import com.google.common.collect.Range; import org.apache.commons.lang3.NotImplementedException; -import java.util.Iterator; import java.util.List; -import java.util.Set; /** * This class implements the function to simplify expression range. @@ -133,20 +128,7 @@ private Expression getExpression(RangeValue value) { } private Expression getExpression(DiscreteValue value) { - Expression reference = value.getReference(); - Set values = value.getValues(); - // NOTICE: it's related with `InPredicateToEqualToRule` - // They are same processes, so must change synchronously. - if (values.size() == 1) { - return new EqualTo(reference, values.iterator().next()); - - // this condition should as same as OrToIn, or else meet dead loop - } else if (values.size() < OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - Iterator iterator = values.iterator(); - return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next())); - } else { - return new InPredicate(reference, Lists.newArrayList(values)); - } + return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(), value.getValues()); } private Expression getExpression(UnknownValue value) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 723224409bfda2..3d8aef2c8428ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -276,6 +276,14 @@ public static Expression trueOrNull(Expression expression) { } } + public static Expression toInPredicateOrEqualTo(Expression reference, Collection values) { + if (values.size() < 2) { + return or(values.stream().map(value -> new EqualTo(reference, value)).collect(Collectors.toList())); + } else { + return new InPredicate(reference, ImmutableList.copyOf(values)); + } + } + /** * Use AND/OR to combine expressions together. */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 13f2789c0a9ffc..34b40efcc29315 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; +import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant; import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; @@ -361,4 +362,18 @@ void testSimplifyRangeAndAddMinMax() { assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 and TB between 60 and 50", "(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40"); } + + @Test + public void testInPredicateExtractNonConstant() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + InPredicateExtractNonConstant.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("TA in (3, 2, 1)", "TA in (3, 2, 1)"); + assertRewriteAfterTypeCoercion("TA in (TB, TC, TB)", "TA = TB or TA = TC"); + assertRewriteAfterTypeCoercion("TA in (3, 2, 1, TB, TC, TB)", "TA in (3, 2, 1) or TA = TB or TA = TC"); + assertRewriteAfterTypeCoercion("IA in (1 + 2, 2 + 3, 3 + TB)", "IA in (cast(1 + 2 as int), cast(2 + 3 as int)) or IA = cast(3 + TB as int)"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java new file mode 100644 index 00000000000000..60b511f44756eb --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.sqltest.SqlTestBase; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Test; + +class InPredicateExtractNonConstantTest extends SqlTestBase { + @Test + public void testExtractNonConstant() { + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + String sql = "select * from T1 where id in (score, score, score + 100)"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalFilter().when(f -> f.getPredicate().toString().equals( + "OR[(id#0 = score#1),(id#0 = (score#1 + 100))]" + ))); + + sql = "select * from T1 where id in (score, score + 10, score + score, score, 10, 20, 30, 100 + 200)"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalFilter().when(f -> f.getPredicate().toString().equals( + "OR[id#0 IN (20, 10, 300, 30),(id#0 = score#1),(id#0 = (score#1 + 10)),(id#0 = (score#1 + score#1))]" + ))); + } +}