Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Multi-stage] Clean up unnecessary checks in rules #14066

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
new PinotAggregateExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotAggregateExchangeNodeInsertRule(RelBuilderFactory factory) {
// NOTE: Explicitly match for LogicalAggregate because after applying the rule, LogicalAggregate is replaced with
// PinotLogicalAggregate, and the rule won't be applied again.
super(operand(LogicalAggregate.class, any()), factory, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
Expand All @@ -40,8 +39,7 @@


/**
* SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate
* as its right child.
* SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate as its right child.
*
* @see CoreRules#PROJECT_TO_SEMI_JOIN
*/
Expand All @@ -50,18 +48,9 @@ public class PinotAggregateToSemiJoinRule extends RelOptRule {
new PinotAggregateToSemiJoinRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotAggregateToSemiJoinRule(RelBuilderFactory factory) {
super(operand(LogicalAggregate.class, any()), factory, null);
}

@Override
@SuppressWarnings("rawtypes")
public boolean matches(RelOptRuleCall call) {
final Aggregate topAgg = call.rel(0);
if (!PinotRuleUtils.isJoin(topAgg.getInput())) {
return false;
}
final Join join = (Join) PinotRuleUtils.unboxRel(topAgg.getInput());
return PinotRuleUtils.isAggregate(join.getInput(1));
super(operand(Aggregate.class,
some(operand(Join.class, some(operand(RelNode.class, any()), operand(Aggregate.class, any()))))), factory,
null);
Comment on lines +51 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I would recommend to start using the new way to instantiate rules instead of this one, which is deprecated. See PinotImplicitTableHintRule in #13943.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I also noticed that. PinotSortExchangeCopyRule is also using the config in the constructor (I believe the benefit is to allow reusing the same rule to match different trees).
We can use a separate PR to clean this up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've recently ended up reading Hive's CalcitePlanner class, which is pretty similar to our QueryEnvironment class and I think that is something we can get inspiration of once our rule logic starts to be more complex.

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private static LogicalProject constructNewProject(LogicalProject oldProject, Log
}
castedNewProjects.add(newNode);
}
return needCast ? LogicalProject.create(oldProject.getInput(), oldProject.getHints(), castedNewProjects,
return needCast ? oldProject.copy(oldProject.getTraitSet(), oldProject.getInput(), castedNewProjects,
oldProject.getRowType()) : newProject;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
*/
package org.apache.pinot.calcite.rel.rules;

import java.util.Collections;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange;

Expand All @@ -36,17 +37,14 @@ public class PinotExchangeEliminationRule extends RelOptRule {
new PinotExchangeEliminationRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotExchangeEliminationRule(RelBuilderFactory factory) {
super(operand(PinotLogicalExchange.class,
some(operand(PinotLogicalExchange.class, some(operand(RelNode.class, any()))))), factory, null);
super(operand(Exchange.class, some(operand(Exchange.class, some(operand(RelNode.class, any()))))), factory, null);
}

@Override
public void onMatch(RelOptRuleCall call) {
PinotLogicalExchange exchange0 = call.rel(0);
PinotLogicalExchange exchange1 = call.rel(1);
Exchange exchange0 = call.rel(0);
RelNode input = call.rel(2);
// convert the call to skip the exchange.
RelNode rel = exchange0.copy(input.getTraitSet(), Collections.singletonList(input));
call.transformTo(rel);
call.transformTo(exchange0.copy(input.getTraitSet(), List.of(input)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
Expand All @@ -33,27 +32,20 @@ public class PinotFilterExpandSearchRule extends RelOptRule {
new PinotFilterExpandSearchRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotFilterExpandSearchRule(RelBuilderFactory factory) {
super(operand(LogicalFilter.class, any()), factory, null);
super(operand(Filter.class, any()), factory, null);
}

@Override
@SuppressWarnings("rawtypes")
public boolean matches(RelOptRuleCall call) {
if (call.rels.length < 1) {
return false;
}
if (call.rel(0) instanceof Filter) {
Filter filter = call.rel(0);
return containsRangeSearch(filter.getCondition());
}
return false;
Filter filter = call.rel(0);
return containsRangeSearch(filter.getCondition());
}

@Override
public void onMatch(RelOptRuleCall call) {
Filter filter = call.rel(0);
RexNode newCondition = RexUtil.expandSearch(filter.getCluster().getRexBuilder(), null, filter.getCondition());
call.transformTo(LogicalFilter.create(filter.getInput(), newCondition));
call.transformTo(filter.copy(filter.getTraitSet(), filter.getInput(), newCondition));
}

private boolean containsRangeSearch(RexNode condition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
*/
package org.apache.pinot.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange;

Expand All @@ -38,19 +36,13 @@ public class PinotJoinExchangeNodeInsertRule extends RelOptRule {
new PinotJoinExchangeNodeInsertRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotJoinExchangeNodeInsertRule(RelBuilderFactory factory) {
super(operand(LogicalJoin.class, any()), factory, null);
super(operand(Join.class, any()), factory, null);
}

@Override
public boolean matches(RelOptRuleCall call) {
if (call.rels.length < 1) {
return false;
}
if (call.rel(0) instanceof Join) {
Join join = call.rel(0);
return !PinotRuleUtils.isExchange(join.getLeft()) && !PinotRuleUtils.isExchange(join.getRight());
}
return false;
Join join = call.rel(0);
return !PinotRuleUtils.isExchange(join.getLeft()) && !PinotRuleUtils.isExchange(join.getRight());
}

@Override
Expand All @@ -73,10 +65,7 @@ public void onMatch(RelOptRuleCall call) {
rightExchange = PinotLogicalExchange.create(rightInput, RelDistributions.hash(joinInfo.rightKeys));
}

RelNode newJoinNode =
new LogicalJoin(join.getCluster(), join.getTraitSet(), join.getHints(), leftExchange, rightExchange,
join.getCondition(), join.getVariablesSet(), join.getJoinType(), join.isSemiJoinDone(),
ImmutableList.copyOf(join.getSystemFieldList()));
call.transformTo(newJoinNode);
call.transformTo(join.copy(join.getTraitSet(), join.getCondition(), leftExchange, rightExchange, join.getJoinType(),
join.isSemiJoinDone()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.pinot.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.Collections;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
Expand All @@ -31,7 +30,6 @@
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
import org.apache.pinot.calcite.rel.hint.PinotHintStrategyTable;
Expand Down Expand Up @@ -121,27 +119,23 @@ public class PinotJoinToDynamicBroadcastRule extends RelOptRule {
new PinotJoinToDynamicBroadcastRule(PinotRuleUtils.PINOT_REL_FACTORY);

public PinotJoinToDynamicBroadcastRule(RelBuilderFactory factory) {
super(operand(LogicalJoin.class, any()), factory, null);
super(operand(Join.class, any()), factory, null);
}

@Override
public boolean matches(RelOptRuleCall call) {
if (call.rels.length < 1 || !(call.rel(0) instanceof Join)) {
return false;
}
Join join = call.rel(0);
String joinStrategyString = PinotHintStrategyTable.getHintOption(join.getHints(),
PinotHintOptions.JOIN_HINT_OPTIONS, PinotHintOptions.JoinHintOptions.JOIN_STRATEGY);
List<String> joinStrategies = joinStrategyString != null ? StringUtils.split(joinStrategyString, ",")
: Collections.emptyList();
boolean explicitOtherStrategy = joinStrategies.size() > 0
&& !joinStrategies.contains(PinotHintOptions.JoinHintOptions.DYNAMIC_BROADCAST_JOIN_STRATEGY);
String joinStrategyString =
PinotHintStrategyTable.getHintOption(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS,
PinotHintOptions.JoinHintOptions.JOIN_STRATEGY);
List<String> joinStrategies =
joinStrategyString != null ? StringUtils.split(joinStrategyString, ",") : Collections.emptyList();
boolean explicitOtherStrategy = !joinStrategies.isEmpty() && !joinStrategies.contains(
PinotHintOptions.JoinHintOptions.DYNAMIC_BROADCAST_JOIN_STRATEGY);

JoinInfo joinInfo = join.analyzeCondition();
RelNode left = join.getLeft() instanceof HepRelVertex ? ((HepRelVertex) join.getLeft()).getCurrentRel()
: join.getLeft();
RelNode right = join.getRight() instanceof HepRelVertex ? ((HepRelVertex) join.getRight()).getCurrentRel()
: join.getRight();
RelNode left = ((HepRelVertex) join.getLeft()).getCurrentRel();
RelNode right = ((HepRelVertex) join.getRight()).getCurrentRel();
return left instanceof Exchange && right instanceof Exchange
// left side can be pushed as dynamic exchange
&& PinotRuleUtils.canPushDynamicBroadcastToLeaf(left.getInput(0))
Expand All @@ -155,16 +149,15 @@ public boolean matches(RelOptRuleCall call) {
@Override
public void onMatch(RelOptRuleCall call) {
Join join = call.rel(0);
PinotLogicalExchange left = (PinotLogicalExchange) (join.getLeft() instanceof HepRelVertex
? ((HepRelVertex) join.getLeft()).getCurrentRel() : join.getLeft());
PinotLogicalExchange right = (PinotLogicalExchange) (join.getRight() instanceof HepRelVertex
? ((HepRelVertex) join.getRight()).getCurrentRel() : join.getRight());
Exchange left = (Exchange) ((HepRelVertex) join.getLeft()).getCurrentRel();
Exchange right = (Exchange) ((HepRelVertex) join.getRight()).getCurrentRel();

// when colocated join hint is given, dynamic broadcast exchange can be hash-distributed b/c
// 1. currently, dynamic broadcast only works against main table off leaf-stage; (e.g. receive node on leaf)
// 2. when hash key are the same but hash functions are different, it can be done via normal hash shuffle.
boolean isColocatedJoin = PinotHintStrategyTable.isHintOptionTrue(join.getHints(),
PinotHintOptions.JOIN_HINT_OPTIONS, PinotHintOptions.JoinHintOptions.IS_COLOCATED_BY_JOIN_KEYS);
boolean isColocatedJoin =
PinotHintStrategyTable.isHintOptionTrue(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS,
PinotHintOptions.JoinHintOptions.IS_COLOCATED_BY_JOIN_KEYS);
PinotLogicalExchange dynamicBroadcastExchange;
RelNode rightInput = right.getInput();
if (isColocatedJoin) {
Expand All @@ -174,10 +167,8 @@ public void onMatch(RelOptRuleCall call) {
RelDistribution dist = RelDistributions.BROADCAST_DISTRIBUTED;
dynamicBroadcastExchange = PinotLogicalExchange.create(rightInput, dist, PinotRelExchangeType.PIPELINE_BREAKER);
}
Join dynamicFilterJoin =
new LogicalJoin(join.getCluster(), join.getTraitSet(), left.getInput(), dynamicBroadcastExchange,
join.getCondition(), join.getVariablesSet(), join.getJoinType(), join.isSemiJoinDone(),
ImmutableList.copyOf(join.getSystemFieldList()));
call.transformTo(dynamicFilterJoin);

call.transformTo(join.copy(join.getTraitSet(), join.getCondition(), left.getInput(), dynamicBroadcastExchange,
join.getJoinType(), join.isSemiJoinDone()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ public PinotRelDistributionTraitRule(RelBuilderFactory factory) {
super(operand(RelNode.class, any()), factory, null);
}

@Override
public boolean matches(RelOptRuleCall call) {
return call.rels.length >= 1;
}

@Override
public void onMatch(RelOptRuleCall call) {
RelNode current = call.rel(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,13 @@

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange;


Expand All @@ -47,41 +43,20 @@ public PinotSetOpExchangeNodeInsertRule(RelBuilderFactory factory) {

@Override
public boolean matches(RelOptRuleCall call) {
if (call.rels.length < 1) {
return false;
}
if (call.rel(0) instanceof SetOp) {
SetOp setOp = call.rel(0);
for (RelNode input : setOp.getInputs()) {
if (PinotRuleUtils.isExchange(input)) {
return false;
}
}
return true;
}
return false;
SetOp setOp = call.rel(0);
return !PinotRuleUtils.isExchange(setOp.getInput(0));
}

@Override
public void onMatch(RelOptRuleCall call) {
SetOp setOp = call.rel(0);
List<RelNode> newInputs = new ArrayList<>();
List<Integer> hashFields =
IntStream.range(0, setOp.getRowType().getFieldCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
for (RelNode input : setOp.getInputs()) {
RelNode exchange = PinotLogicalExchange.create(input, RelDistributions.hash(hashFields));
List<RelNode> inputs = setOp.getInputs();
List<RelNode> newInputs = new ArrayList<>(inputs.size());
for (RelNode input : inputs) {
RelNode exchange = PinotLogicalExchange.create(input,
RelDistributions.hash(ImmutableIntList.range(0, setOp.getRowType().getFieldCount())));
newInputs.add(exchange);
}
SetOp newSetOpNode;
if (setOp instanceof LogicalUnion) {
newSetOpNode = new LogicalUnion(setOp.getCluster(), setOp.getTraitSet(), newInputs, setOp.all);
} else if (setOp instanceof LogicalIntersect) {
newSetOpNode = new LogicalIntersect(setOp.getCluster(), setOp.getTraitSet(), newInputs, setOp.all);
} else if (setOp instanceof LogicalMinus) {
newSetOpNode = new LogicalMinus(setOp.getCluster(), setOp.getTraitSet(), newInputs, setOp.all);
} else {
throw new UnsupportedOperationException("Unsupported set op node: " + setOp);
}
call.transformTo(newSetOpNode);
call.transformTo(setOp.copy(setOp.getTraitSet(), newInputs));
}
}
Loading
Loading