Skip to content

Commit

Permalink
[feature](nereids) in predicate extract non constant expressions (#46794)
Browse files Browse the repository at this point in the history

Problem Summary:
If an IN predicate contains non-constant options, the backend processes it
less efficiently, so we need to extract the non-constant options out of the
IN predicate.

This PR adds an expression rewrite rule, InPredicateExtractNonConstant, which
extracts all the non-constant options out of the IN predicate. For example:

```
k1  in (k2,  k3 + 3,   1, 2, 3 + 3)  => k1 in (1, 2, 3 + 3) or k1 = k2 or k1 = k3 + 3
```
  • Loading branch information
yujun777 authored Jan 13, 2025
1 parent 84d77a6 commit 565edd9
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant;
import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
import org.apache.doris.nereids.rules.expression.rules.MedianConvert;
import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc;
Expand All @@ -47,6 +48,7 @@ public class ExpressionNormalization extends ExpressionRewrite {
SupportJavaDateFormatter.INSTANCE,
NormalizeBinaryPredicatesRule.INSTANCE,
InPredicateDedup.INSTANCE,
InPredicateExtractNonConstant.INSTANCE,
InPredicateToEqualToRule.INSTANCE,
SimplifyNotExprRule.INSTANCE,
SimplifyArithmeticRule.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum ExpressionRuleType {
FOLD_CONSTANT_ON_BE,
FOLD_CONSTANT_ON_FE,
IN_PREDICATE_DEDUP,
IN_PREDICATE_EXTRACT_NON_CONSTANT,
IN_PREDICATE_TO_EQUAL_TO,
LIKE_TO_EQUAL,
MERGE_DATE_TRUNC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,29 @@
public class InPredicateDedup implements ExpressionPatternRuleFactory {
public static final InPredicateDedup INSTANCE = new InPredicateDedup();

// In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options.
// It takes a long time to apply this rule. So set a threshold for the max number.
public static final int REWRITE_OPTIONS_MAX_SIZE = 200;

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(InPredicate.class).then(InPredicateDedup::dedup)
matchesType(InPredicate.class)
.when(inPredicate -> inPredicate.getOptions().size() <= REWRITE_OPTIONS_MAX_SIZE)
.then(InPredicateDedup::dedup)
.toRule(ExpressionRuleType.IN_PREDICATE_DEDUP)
);
}

/** dedup */
public static Expression dedup(InPredicate inPredicate) {
// In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options.
// It takes a long time to apply this rule. So set a threshold for the max number.
int optionSize = inPredicate.getOptions().size();
if (optionSize > 200) {
return inPredicate;
}
ImmutableSet.Builder<Expression> newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity());
for (Expression option : inPredicate.getOptions()) {
newOptionsBuilder.add(option);
}

Set<Expression> newOptions = newOptionsBuilder.build();
if (newOptions.size() == optionSize) {
if (newOptions.size() == inPredicate.getOptions().size()) {
return inPredicate;
}
return new InPredicate(inPredicate.getCompareExpr(), newOptions);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.apache.hadoop.util.Lists;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Extract the non-constant options of an InPredicate into separate equalities. For example:
 * where k1 in (k2, k3, 10, 20, 30) ==> where k1 in (10, 20, 30) or k1 = k2 or k1 = k3.
 * It's because an in predicate which contains non-constant options reduces performance when the backend handles it.
 */
public class InPredicateExtractNonConstant implements ExpressionPatternRuleFactory {
    public static final InPredicateExtractNonConstant INSTANCE = new InPredicateExtractNonConstant();

    /**
     * Build the single rule of this factory: it matches every InPredicate whose option count does not
     * exceed {@code InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE} and hands it to {@code rewrite}.
     */
    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesType(InPredicate.class)
                        .when(candidate -> candidate.getOptions().size()
                                <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE)
                        .then(this::rewrite)
                        .toRule(ExpressionRuleType.IN_PREDICATE_EXTRACT_NON_CONSTANT)
        );
    }

    /**
     * Split the options of {@code inPredicate} into constant and non-constant ones and rebuild the predicate.
     *
     * @param inPredicate the in predicate to rewrite
     * @return {@code inPredicate} itself when no option is non-constant; otherwise the OR of
     *         one predicate over the constant options (if there are any) followed by one
     *         {@code EqualTo} per distinct non-constant option
     */
    private Expression rewrite(InPredicate inPredicate) {
        int optionCount = inPredicate.getOptions().size();
        // NOTE(review): `Lists` resolves to org.apache.hadoop.util.Lists through this file's imports,
        // not to Guava's Lists - confirm that this is intended.
        List<Expression> constantOptions = Lists.newArrayListWithExpectedSize(optionCount);
        // the linked set drops repeated non-constant options while keeping their first-seen order
        Set<Expression> nonConstantOptions = Sets.newLinkedHashSetWithExpectedSize(optionCount);
        for (Expression option : inPredicate.getOptions()) {
            if (option.isConstant()) {
                constantOptions.add(option);
            } else {
                nonConstantOptions.add(option);
            }
        }
        if (nonConstantOptions.isEmpty()) {
            // nothing to extract, keep the original predicate untouched
            return inPredicate;
        }

        // NOTE(review): the compare expression is repeated in every disjunct below; confirm this is
        // acceptable when the compare expression is not deterministic.
        Expression compareExpr = inPredicate.getCompareExpr();
        List<Expression> disjuncts = Lists.newArrayListWithExpectedSize(nonConstantOptions.size() + 1);
        if (!constantOptions.isEmpty()) {
            // the constant options stay together, in their original order, in front of the equalities
            disjuncts.add(ExpressionUtils.toInPredicateOrEqualTo(compareExpr, constantOptions));
        }
        disjuncts.addAll(nonConstantOptions.stream()
                .map(option -> new EqualTo(compareExpr, option))
                .collect(Collectors.toList()));
        return ExpressionUtils.or(disjuncts);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
Expand Down Expand Up @@ -69,8 +67,6 @@ public enum Mode {
public static final OrToIn EXTRACT_MODE_INSTANCE = new OrToIn(Mode.extractMode);
public static final OrToIn REPLACE_MODE_INSTANCE = new OrToIn(Mode.replaceMode);

public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;

private final Mode mode;

public OrToIn(Mode mode) {
Expand Down Expand Up @@ -196,18 +192,9 @@ private Map<Expression, Set<Literal>> mergeCandidates(
}

private Expression candidatesToFinalResult(Map<Expression, Set<Literal>> candidates) {
List<Expression> conjuncts = new ArrayList<>();
for (Expression key : candidates.keySet()) {
Set<Literal> literals = candidates.get(key);
if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
for (Literal literal : literals) {
conjuncts.add(new EqualTo(key, literal));
}
} else {
conjuncts.add(new InPredicate(key, ImmutableList.copyOf(literals)));
}
}
return ExpressionUtils.and(conjuncts);
return ExpressionUtils.and(candidates.entrySet().stream()
.map(entry -> ExpressionUtils.toInPredicateOrEqualTo(entry.getKey(), entry.getValue()))
.collect(Collectors.toList()));
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context)
@Override
public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
// only handle `NumericType` and `DateLikeType`
if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
&& ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
&& (ExpressionUtils.matchNumericType(inPredicate.getOptions())
|| ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) {
return ValueDesc.discrete(context, inPredicate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -45,9 +42,7 @@
import com.google.common.collect.Range;
import org.apache.commons.lang3.NotImplementedException;

import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
* This class implements the function to simplify expression range.
Expand Down Expand Up @@ -133,20 +128,7 @@ private Expression getExpression(RangeValue value) {
}

private Expression getExpression(DiscreteValue value) {
Expression reference = value.getReference();
Set<Literal> values = value.getValues();
// NOTICE: it's related with `InPredicateToEqualToRule`
// They are same processes, so must change synchronously.
if (values.size() == 1) {
return new EqualTo(reference, values.iterator().next());

// this condition should as same as OrToIn, or else meet dead loop
} else if (values.size() < OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
Iterator<Literal> iterator = values.iterator();
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
} else {
return new InPredicate(reference, Lists.newArrayList(values));
}
return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(), value.getValues());
}

private Expression getExpression(UnknownValue value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@ public static Expression trueOrNull(Expression expression) {
}
}

/**
 * Build a predicate that compares {@code reference} against {@code values}.
 *
 * @param reference the expression used as the left side of every produced predicate
 * @param values the expressions {@code reference} is compared with
 * @return an {@code InPredicate} over all values when there are two or more of them; otherwise the
 *         result of {@code or(...)} applied to one {@code EqualTo} per value
 */
public static Expression toInPredicateOrEqualTo(Expression reference, Collection<? extends Expression> values) {
    if (values.size() >= 2) {
        return new InPredicate(reference, ImmutableList.copyOf(values));
    }
    // fewer than two values: no in predicate is built, only equalities (if any)
    return or(values.stream()
            .map(candidate -> new EqualTo(reference, candidate))
            .collect(Collectors.toList()));
}

/**
* Use AND/OR to combine expressions together.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant;
import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
Expand Down Expand Up @@ -361,4 +362,18 @@ void testSimplifyRangeAndAddMinMax() {
assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 and TB between 60 and 50",
"(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40");
}

@Test
public void testInPredicateExtractNonConstant() {
    // run the rule under test alone
    executor = new ExpressionRuleExecutor(
            ImmutableList.of(bottomUp(InPredicateExtractNonConstant.INSTANCE)));

    // each row is {input expression, expected expression after the rewrite}
    String[][] cases = {
            {"TA in (3, 2, 1)", "TA in (3, 2, 1)"},
            {"TA in (TB, TC, TB)", "TA = TB or TA = TC"},
            {"TA in (3, 2, 1, TB, TC, TB)", "TA in (3, 2, 1) or TA = TB or TA = TC"},
            {"IA in (1 + 2, 2 + 3, 3 + TB)",
                    "IA in (cast(1 + 2 as int), cast(2 + 3 as int)) or IA = cast(3 + TB as int)"},
    };
    for (String[] testCase : cases) {
        assertRewriteAfterTypeCoercion(testCase[0], testCase[1]);
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.sqltest.SqlTestBase;
import org.apache.doris.nereids.util.PlanChecker;

import org.junit.jupiter.api.Test;

/**
 * SQL level checks of the filter predicate produced for in predicates with non-constant options.
 */
class InPredicateExtractNonConstantTest extends SqlTestBase {
    @Test
    public void testExtractNonConstant() {
        connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");

        // only non-constant options, one of them repeated
        assertFilterPredicate(
                "select * from T1 where id in (score, score, score + 100)",
                "OR[(id#0 = score#1),(id#0 = (score#1 + 100))]");

        // constant and non-constant options mixed
        assertFilterPredicate(
                "select * from T1 where id in (score, score + 10, score + score, score, 10, 20, 30, 100 + 200)",
                "OR[id#0 IN (20, 10, 300, 30),(id#0 = score#1),(id#0 = (score#1 + 10)),(id#0 = (score#1 + score#1))]");
    }

    /** Analyze and rewrite {@code sql}, then require a logical filter whose predicate prints as expected. */
    private void assertFilterPredicate(String sql, String expectedPredicate) {
        PlanChecker.from(connectContext)
                .analyze(sql)
                .rewrite()
                .matches(
                        logicalFilter().when(filter ->
                                filter.getPredicate().toString().equals(expectedPredicate)));
    }
}

0 comments on commit 565edd9

Please sign in to comment.