Skip to content

Commit

Permalink
feat: projection merge and projection filter transpose rule (#162)
Browse files Browse the repository at this point in the history
This PR implements a part of the projection transpose series of rules.
It also includes a fair amount of refactoring.

### Projection Merge Rule
- This rule matches on two projection nodes and combines the two nodes
into one.
- It is added to the heuristic optimizer pass before the cascades
optimizer. In the future, it should also be added to a pass after the
cascades optimizer.

### Projection Filter Transpose Rule
- This rule pushes a projection node past a filter node. If
the filter node contains columns that are not in this projection node,
the topmost projection node is also kept.
- It is added as a cascades rule.

### Refactoring
Relevant functions for projection transpose rules can be found in
`project_transpose_common.rs`. Rules are implemented in separate files
as a part of the `project_transpose` module rather than all in one
file. Similarly, `FilterProjectTransposeRule` and `ProjectionPullUpJoin`
were moved into this module.

### Testing
Unit tests using the dummy heuristic optimizer were implemented.

---------

Signed-off-by: AveryQi115 <averyqi115@gmail.com>
Co-authored-by: Benjamin O <jeep70cp@gmail.com>
Co-authored-by: AveryQi115 <averyqi115@gmail.com>
  • Loading branch information
3 people authored May 1, 2024
1 parent 74dc3ff commit 9ef6339
Show file tree
Hide file tree
Showing 16 changed files with 948 additions and 270 deletions.
10 changes: 8 additions & 2 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use rules::{
EliminateJoinRule, EliminateLimitRule, FilterAggTransposeRule, FilterCrossJoinTransposeRule,
FilterInnerJoinTransposeRule, FilterMergeRule, FilterProjectTransposeRule,
FilterSortTransposeRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule,
ProjectionPullUpJoin, SimplifyFilterRule, SimplifyJoinCondRule,
ProjectFilterTransposeRule, ProjectMergeRule, ProjectionPullUpJoin, SimplifyFilterRule,
SimplifyJoinCondRule,
};

pub use optd_core::rel_node::Value;
Expand Down Expand Up @@ -87,6 +88,8 @@ impl DatafusionOptimizer {
Arc::new(EliminateLimitRule::new()),
Arc::new(EliminateDuplicatedSortExprRule::new()),
Arc::new(EliminateDuplicatedAggExprRule::new()),
Arc::new(ProjectMergeRule::new()),
Arc::new(FilterMergeRule::new()),
]
}

Expand All @@ -97,11 +100,14 @@ impl DatafusionOptimizer {
for rule in rules {
rule_wrappers.push(RuleWrapper::new_cascades(rule));
}
// project transpose rules
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
ProjectFilterTransposeRule::new(),
)));
// add all filter pushdown rules as heuristic rules
rule_wrappers.push(RuleWrapper::new_heuristic(Arc::new(
FilterProjectTransposeRule::new(),
)));
rule_wrappers.push(RuleWrapper::new_heuristic(Arc::new(FilterMergeRule::new())));
rule_wrappers.push(RuleWrapper::new_heuristic(Arc::new(
FilterCrossJoinTransposeRule::new(),
)));
Expand Down
24 changes: 24 additions & 0 deletions optd-datafusion-repr/src/plan_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,30 @@ impl Expr {
.unwrap(),
)
}

/// Recursively collects every column reference contained in this expression.
///
/// If the expression itself is a `ColumnRef`, the result is a single-element
/// vector holding it. Otherwise the children are visited in order and their
/// column references are concatenated; duplicates are kept.
///
/// Children of type `List` are currently skipped, so column references nested
/// inside a list are NOT reported (see the TODO below).
///
/// # Panics
/// Panics if `self` is not an expression node, or if a non-`List` child
/// cannot be converted into an `Expr`.
pub fn get_column_refs(&self) -> Vec<Expr> {
    assert!(self.typ().is_expression());
    if let OptRelNodeTyp::ColumnRef = self.typ() {
        let col_ref = Expr::from_rel_node(self.0.clone()).unwrap();
        return vec![col_ref];
    }

    // Borrow the children instead of cloning the whole vector, and flatten
    // the per-child results as they are produced.
    self.0
        .children
        .iter()
        .flat_map(|child| {
            if child.typ == OptRelNodeTyp::List {
                // TODO: What should we do with List?
                return vec![];
            }
            Expr::from_rel_node(child.clone())
                .unwrap()
                .get_column_refs()
        })
        .collect()
}
}

impl OptRelNode for Expr {
Expand Down
73 changes: 1 addition & 72 deletions optd-datafusion-repr/src/plan_nodes/projection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::expr::ExprList;
use super::macros::define_plan_node;

use super::{ColumnRefExpr, Expr, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode};
use super::{OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode};

#[derive(Clone, Debug)]
pub struct LogicalProjection(pub PlanNode);
Expand All @@ -26,74 +26,3 @@ define_plan_node!(
{ 1, exprs: ExprList }
]
);

/// This struct holds the mapping from original columns to projected columns.
///
/// # Example
/// With the following plan:
/// | Filter (#0 < 5)
/// |
/// |-| Projection [#2, #3]
/// |- Scan [#0, #1, #2, #3]
///
/// The computed projection mapping is:
/// #2 -> #0
/// #3 -> #1
pub struct ProjectionMapping {
    /// `forward[i]` is the child column that projection output column `i` refers to.
    forward: Vec<usize>,
    /// `_backward[c]` is the projection output column that child column `c`
    /// maps to, or `None` if `c` is not projected.
    _backward: Vec<Option<usize>>,
}

impl ProjectionMapping {
    /// Builds the mapping from `mapping`, where `mapping[i]` is the child
    /// column referenced by projection output column `i`.
    ///
    /// If a child column is projected more than once, the inverse table keeps
    /// the last projection column that refers to it. This constructor always
    /// returns `Some`.
    pub fn build(mapping: Vec<usize>) -> Option<Self> {
        // Size the inverse table once: it must cover the largest referenced column.
        let backward_len = mapping.iter().max().map_or(0, |&max_col| max_col + 1);
        let mut backward = vec![None; backward_len];
        for (proj_col, &child_col) in mapping.iter().enumerate() {
            backward[child_col] = Some(proj_col);
        }
        Some(Self {
            forward: mapping,
            _backward: backward,
        })
    }

    /// Returns the child column that projection output column `col` refers to.
    ///
    /// # Panics
    /// Panics if `col` is not a valid projection output column index.
    pub fn projection_col_refers_to(&self, col: usize) -> usize {
        self.forward[col]
    }

    /// Returns the projection output column that child column `col` maps to,
    /// or `None` if the column is not projected (including columns beyond the
    /// largest projected index, which previously caused an out-of-bounds panic).
    pub fn _original_col_maps_to(&self, col: usize) -> Option<usize> {
        self._backward.get(col).copied().flatten()
    }

    /// Recursively rewrites all ColumnRefs in an Expr to *undo* the projection
    /// condition. You might want to do this if you are pushing something
    /// through a projection, or pulling a projection up.
    ///
    /// Column indices below the projection's output width are translated
    /// through the mapping; larger indices are shifted by
    /// `child_schema_len - <projection width>`.
    ///
    /// # Example
    /// If we have a projection node, mapping column A to column B (A -> B)
    /// All B's in `cond` will be rewritten as A.
    pub fn rewrite_condition(&self, cond: Expr, child_schema_len: usize) -> Expr {
        let proj_schema_size = self.forward.len();
        cond.rewrite_column_refs(&|idx| {
            Some(if idx < proj_schema_size {
                self.projection_col_refers_to(idx)
            } else {
                idx - proj_schema_size + child_schema_len
            })
        })
        .unwrap()
    }
}

impl LogicalProjection {
    /// Computes the column mapping described by a projection's expression list.
    ///
    /// Each entry of `exprs` is converted to a `ColumnRefExpr` and its index
    /// recorded; as soon as one entry cannot be converted, `None` is returned.
    pub fn compute_column_mapping(exprs: &ExprList) -> Option<ProjectionMapping> {
        let column_indices = exprs
            .to_vec()
            .into_iter()
            .map(|expr| {
                ColumnRefExpr::from_rel_node(expr.into_rel_node()).map(|col_ref| col_ref.index())
            })
            .collect::<Option<Vec<_>>>()?;
        ProjectionMapping::build(column_indices)
    }
}
12 changes: 8 additions & 4 deletions optd-datafusion-repr/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod filter_pushdown;
mod joins;
mod macros;
mod physical;
mod project_transpose;

// pub use filter_join::FilterJoinPullUpRule;
pub use eliminate_duplicated_expr::{
Expand All @@ -15,9 +16,12 @@ pub use eliminate_limit::EliminateLimitRule;
pub use filter::{EliminateFilterRule, SimplifyFilterRule, SimplifyJoinCondRule};
pub use filter_pushdown::{
FilterAggTransposeRule, FilterCrossJoinTransposeRule, FilterInnerJoinTransposeRule,
FilterMergeRule, FilterProjectTransposeRule, FilterSortTransposeRule,
};
pub use joins::{
EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, ProjectionPullUpJoin,
FilterMergeRule, FilterSortTransposeRule,
};
pub use joins::{EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule};
pub use physical::PhysicalConversionRule;
pub use project_transpose::{
project_filter_transpose::{FilterProjectTransposeRule, ProjectFilterTransposeRule},
project_join_transpose::ProjectionPullUpJoin,
project_merge::ProjectMergeRule,
};
116 changes: 4 additions & 112 deletions optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use optd_core::{optimizer::Optimizer, rel_node::RelNode};

use crate::plan_nodes::{
ColumnRefExpr, Expr, ExprList, JoinType, LogOpExpr, LogOpType, LogicalAgg, LogicalFilter,
LogicalJoin, LogicalProjection, LogicalSort, OptRelNode, OptRelNodeTyp, PlanNode,
LogicalJoin, LogicalSort, OptRelNode, OptRelNodeTyp, PlanNode,
};
use crate::properties::schema::SchemaPropertyBuilder;

Expand Down Expand Up @@ -122,36 +122,6 @@ fn categorize_conds(mut categorization_fn: impl FnMut(Expr, &Vec<Expr>), cond: E
}
}

define_rule!(
    FilterProjectTransposeRule,
    apply_filter_project_transpose,
    (Filter, (Projection, child, [exprs]), [cond])
);

/// Pushes a filter below a projection: `Filter(Projection(child))` becomes
/// `Projection(Filter(child))`, with the filter condition rewritten through
/// the projection's column mapping so that it refers to the child's columns.
///
/// Datafusion only pushes filter past project when the project does not contain
/// volatile (i.e. non-deterministic) expressions that are present in the filter
/// Calcite only checks if the projection contains a windowing calculation
/// We check neither of those things and do it always (which may be wrong)
///
/// The rule does not fire (returns an empty vector) when no column mapping can
/// be computed for the projection's expression list, instead of panicking.
fn apply_filter_project_transpose(
    optimizer: &impl Optimizer<OptRelNodeTyp>,
    FilterProjectTransposeRulePicks { child, exprs, cond }: FilterProjectTransposeRulePicks,
) -> Vec<RelNode<OptRelNodeTyp>> {
    let child_schema_len = optimizer
        .get_property::<SchemaPropertyBuilder>(child.clone().into(), 0)
        .len();

    let child = PlanNode::from_group(child.into());
    let cond_as_expr = Expr::from_rel_node(cond.into()).unwrap();
    let exprs = ExprList::from_rel_node(exprs.into()).unwrap();

    // Without a mapping the condition cannot be rewritten; leave the plan
    // untouched, as the projection pull-up rule does in the same situation.
    let Some(proj_col_map) = LogicalProjection::compute_column_mapping(&exprs) else {
        return vec![];
    };
    let rewritten_cond = proj_col_map.rewrite_condition(cond_as_expr, child_schema_len);

    let new_filter_node = LogicalFilter::new(child, rewritten_cond);
    let new_proj = LogicalProjection::new(new_filter_node.into_plan_node(), exprs);
    vec![new_proj.into_rel_node().as_ref().clone()]
}

define_rule!(
FilterMergeRule,
apply_filter_merge,
Expand Down Expand Up @@ -451,12 +421,12 @@ mod tests {
use crate::{
plan_nodes::{
BinOpExpr, BinOpType, ColumnRefExpr, ConstantExpr, ExprList, LogOpExpr, LogOpType,
LogicalAgg, LogicalFilter, LogicalJoin, LogicalProjection, LogicalScan, LogicalSort,
OptRelNode, OptRelNodeTyp,
LogicalAgg, LogicalFilter, LogicalJoin, LogicalScan, LogicalSort, OptRelNode,
OptRelNodeTyp,
},
rules::{
FilterAggTransposeRule, FilterInnerJoinTransposeRule, FilterMergeRule,
FilterProjectTransposeRule, FilterSortTransposeRule,
FilterSortTransposeRule,
},
testing::new_test_optimizer,
};
Expand Down Expand Up @@ -538,84 +508,6 @@ mod tests {
assert_eq!(col_4.value().as_i32(), 1);
}

// Checks that a filter sitting on top of a projection (here with an empty
// expression list) over a scan is moved below the projection.
#[test]
fn push_past_proj_basic() {
let mut test_optimizer = new_test_optimizer(Arc::new(FilterProjectTransposeRule::new()));

let scan = LogicalScan::new("customer".into());
let proj = LogicalProjection::new(scan.into_plan_node(), ExprList::new(vec![]));

// Filter condition: #0 = 5
let filter_expr = BinOpExpr::new(
ColumnRefExpr::new(0).into_expr(),
ConstantExpr::int32(5).into_expr(),
BinOpType::Eq,
)
.into_expr();

let filter = LogicalFilter::new(proj.into_plan_node(), filter_expr);
let plan = test_optimizer.optimize(filter.into_rel_node()).unwrap();

// After optimization the projection is the root and the filter is its child.
assert_eq!(plan.typ, OptRelNodeTyp::Projection);
assert!(matches!(plan.child(0).typ, OptRelNodeTyp::Filter));
}

// Checks that, when a conjunctive filter is pushed below a projection of
// column references, the filter's column indices are rewritten to the
// columns the projection refers to.
#[test]
fn push_past_proj_adv() {
let mut test_optimizer = new_test_optimizer(Arc::new(FilterProjectTransposeRule::new()));

let scan = LogicalScan::new("customer".into());
// Projection output columns #0..#3 refer to scan columns #0, #4, #5, #7.
let proj = LogicalProjection::new(
scan.into_plan_node(),
ExprList::new(vec![
ColumnRefExpr::new(0).into_expr(),
ColumnRefExpr::new(4).into_expr(),
ColumnRefExpr::new(5).into_expr(),
ColumnRefExpr::new(7).into_expr(),
]),
);

let filter_expr = LogOpExpr::new(
LogOpType::And,
ExprList::new(vec![
BinOpExpr::new(
// Projection column #1 refers to scan column #4
ColumnRefExpr::new(1).into_expr(),
ConstantExpr::int32(5).into_expr(),
BinOpType::Eq,
)
.into_expr(),
BinOpExpr::new(
// Projection column #3 refers to scan column #7
ColumnRefExpr::new(3).into_expr(),
ConstantExpr::int32(6).into_expr(),
BinOpType::Eq,
)
.into_expr(),
]),
);

let filter = LogicalFilter::new(proj.into_plan_node(), filter_expr.into_expr());

let plan = test_optimizer.optimize(filter.into_rel_node()).unwrap();

// The projection is now the root, with the filter directly beneath it.
assert!(matches!(plan.typ, OptRelNodeTyp::Projection));
let plan_filter = LogicalFilter::from_rel_node(plan.child(0)).unwrap();
assert!(matches!(plan_filter.0.typ(), OptRelNodeTyp::Filter));
let plan_filter_expr =
LogOpExpr::from_rel_node(plan_filter.cond().into_rel_node()).unwrap();
assert!(matches!(plan_filter_expr.op_type(), LogOpType::And));
// First conjunct: the column reference was rewritten from #1 to #4.
let op_0 = BinOpExpr::from_rel_node(plan_filter_expr.children()[0].clone().into_rel_node())
.unwrap();
let col_0 =
ColumnRefExpr::from_rel_node(op_0.left_child().clone().into_rel_node()).unwrap();
assert_eq!(col_0.index(), 4);
// Second conjunct: the column reference was rewritten from #3 to #7.
let op_1 = BinOpExpr::from_rel_node(plan_filter_expr.children()[1].clone().into_rel_node())
.unwrap();
let col_1 =
ColumnRefExpr::from_rel_node(op_1.left_child().clone().into_rel_node()).unwrap();
assert_eq!(col_1.index(), 7);
}

#[test]
fn push_past_join_conjunction() {
// Test pushing a complex filter past a join, where one clause can
Expand Down
56 changes: 0 additions & 56 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,59 +309,3 @@ fn apply_hash_join(
}
vec![]
}

// (Proj A) join B -> (Proj (A join B))
define_rule!(
    ProjectionPullUpJoin,
    apply_projection_pull_up_join,
    (
        Join(JoinType::Inner),
        (Projection, left, [list]),
        right,
        [cond]
    )
);

/// Pulls a projection on the left input of an inner join above the join.
///
/// The join condition is rewritten through the projection's column mapping,
/// and the new top-level projection keeps the original projection list
/// followed by one column reference per column of the right input.
/// Returns an empty vector when no column mapping can be computed.
fn apply_projection_pull_up_join(
    optimizer: &impl Optimizer<OptRelNodeTyp>,
    ProjectionPullUpJoinPicks {
        left,
        right,
        list,
        cond,
    }: ProjectionPullUpJoinPicks,
) -> Vec<RelNode<OptRelNodeTyp>> {
    let left_group = Arc::new(left);
    let right_group = Arc::new(right);
    let proj_exprs = ExprList::from_rel_node(Arc::new(list)).unwrap();

    let projection =
        LogicalProjection::new(PlanNode::from_group(left_group.clone()), proj_exprs.clone());
    let mapping = match LogicalProjection::compute_column_mapping(&projection.exprs()) {
        Some(mapping) => mapping,
        None => return vec![],
    };

    // TODO(chi): support capture projection node.
    let left_schema = optimizer.get_property::<SchemaPropertyBuilder>(left_group.clone(), 0);
    let right_schema = optimizer.get_property::<SchemaPropertyBuilder>(right_group.clone(), 0);
    let left_len = left_schema.len();

    // Original projection list, then every right-side column at its position
    // in the join output (offset by the width of the left input).
    let mut pulled_up_exprs = proj_exprs.to_vec();
    pulled_up_exprs.extend((0..right_schema.len()).map(|offset| {
        let col: Expr = ColumnRefExpr::new(offset + left_len).into_expr();
        col
    }));

    let join = LogicalJoin::new(
        PlanNode::from_group(left_group),
        PlanNode::from_group(right_group),
        mapping.rewrite_condition(Expr::from_rel_node(Arc::new(cond)).unwrap(), left_len),
        JoinType::Inner,
    );
    let pulled_up = LogicalProjection::new(join.into_plan_node(), ExprList::new(pulled_up_exprs));
    vec![pulled_up.into_rel_node().as_ref().clone()]
}
4 changes: 4 additions & 0 deletions optd-datafusion-repr/src/rules/project_transpose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod project_filter_transpose;
pub mod project_join_transpose;
pub mod project_merge;
pub mod project_transpose_common;
Loading

0 comments on commit 9ef6339

Please sign in to comment.