use std::sync::Arc;
use crate::{OptimizerConfig, OptimizerRule};
use crate::join_key_set::JoinKeySet;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
/// Optimizer rule that tries to replace cross joins with inner joins by
/// pulling equality predicates out of a parent filter (and out of existing
/// inner-join keys) and using them as join keys.
///
/// The rule carries no state; see its `OptimizerRule` implementation for the
/// rewrite itself.
#[derive(Default)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new instance of the rule (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self
    }
}
impl OptimizerRule for EliminateCrossJoin {
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let plan_schema = Arc::clone(plan.schema());
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<LogicalPlan> = vec![];
let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
let rewriteable = matches!(
filter.input.as_ref(),
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
}) | LogicalPlan::CrossJoin(_)
);
if !rewriteable {
return rewrite_children(self, LogicalPlan::Filter(filter), config);
}
if !can_flatten_join_inputs(&filter.input) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let Filter {
input, predicate, ..
} = filter;
flatten_join_inputs(
unwrap_arc(input),
&mut possible_join_keys,
&mut all_inputs,
)?;
extract_possible_join_keys(&predicate, &mut possible_join_keys);
Some(predicate)
} else if matches!(
plan,
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
) {
if !can_flatten_join_inputs(&plan) {
return Ok(Transformed::no(plan));
}
flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?;
None
} else {
return rewrite_children(self, plan, config);
};
let mut all_join_keys = JoinKeySet::new();
let mut left = all_inputs.remove(0);
while !all_inputs.is_empty() {
left = find_inner_join(
left,
&mut all_inputs,
&possible_join_keys,
&mut all_join_keys,
)?;
}
left = rewrite_children(self, left, config)?.data;
if &plan_schema != left.schema() {
left = LogicalPlan::Projection(Projection::new_from_schema(
Arc::new(left),
Arc::clone(&plan_schema),
));
}
let Some(predicate) = parent_predicate else {
return Ok(Transformed::yes(left));
};
if all_join_keys.is_empty() {
Filter::try_new(predicate, Arc::new(left))
.map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
} else {
match remove_join_expressions(predicate, &all_join_keys) {
Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
.map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
_ => Ok(Transformed::yes(left)),
}
}
}
fn name(&self) -> &str {
"eliminate_cross_join"
}
}
/// Applies `optimizer` to every direct child of `plan`.
///
/// If at least one child was changed, the schema of `plan` is recomputed so it
/// reflects the new children; otherwise the plan is handed back as-is.
fn rewrite_children(
    optimizer: &impl OptimizerRule,
    plan: LogicalPlan,
    config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
    let rewritten = plan.map_children(|child| optimizer.rewrite(child, config))?;
    if !rewritten.transformed {
        return Ok(rewritten);
    }
    rewritten.map_data(|node| node.recompute_schema())
}
/// Recursively flattens a tree of inner joins and cross joins.
///
/// * The `on` keys of every inner join encountered are added to
///   `possible_join_keys`.
/// * Every non-join subtree reached (including non-inner joins) is appended
///   to `all_inputs`, in left-to-right order.
///
/// # Errors
///
/// Returns an internal error if an inner join with a filter expression is
/// encountered.
fn flatten_join_inputs(
    plan: LogicalPlan,
    possible_join_keys: &mut JoinKeySet,
    all_inputs: &mut Vec<LogicalPlan>,
) -> Result<()> {
    // Peel off the two join sides, or record a leaf and stop.
    let (left, right) = match plan {
        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
            if join.filter.is_some() {
                return internal_err!(
                    "should not have filter in inner join in flatten_join_inputs"
                );
            }
            possible_join_keys.insert_all_owned(join.on);
            (join.left, join.right)
        }
        LogicalPlan::CrossJoin(join) => (join.left, join.right),
        leaf => {
            all_inputs.push(leaf);
            return Ok(());
        }
    };

    flatten_join_inputs(unwrap_arc(left), possible_join_keys, all_inputs)?;
    flatten_join_inputs(unwrap_arc(right), possible_join_keys, all_inputs)
}
/// Returns `true` if `plan` is a join tree that `flatten_join_inputs` can
/// process without error.
///
/// That is the case when `plan` is a cross join or a filter-less inner join,
/// and the same holds recursively for every direct child that is itself an
/// inner join or cross join. Children of any other kind are not inspected.
fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
    let root_is_flattenable = match plan {
        LogicalPlan::Join(join) => {
            join.join_type == JoinType::Inner && join.filter.is_none()
        }
        LogicalPlan::CrossJoin(_) => true,
        _ => false,
    };

    root_is_flattenable
        && plan.inputs().into_iter().all(|child| {
            let is_nested_join = matches!(
                child,
                LogicalPlan::Join(Join {
                    join_type: JoinType::Inner,
                    ..
                }) | LogicalPlan::CrossJoin(_)
            );
            // Only nested inner/cross joins need to be checked further.
            !is_nested_join || can_flatten_join_inputs(child)
        })
}
/// Joins `left_input` with one plan taken out of `rights`.
///
/// The first plan in `rights` for which at least one of `possible_join_keys`
/// forms a valid, hashable equijoin key pair with `left_input` is removed from
/// `rights` and combined with `left_input` into an inner join on those keys;
/// the keys used are also recorded in `all_join_keys`.
///
/// If no such plan exists, the first plan in `rights` is removed and combined
/// with `left_input` into a cross join.
///
/// The inner join created here always has `filter: None` and
/// `null_equals_null: false`.
///
/// `rights` must not be empty (removing from an empty `Vec` panics).
fn find_inner_join(
    left_input: LogicalPlan,
    rights: &mut Vec<LogicalPlan>,
    possible_join_keys: &JoinKeySet,
    all_join_keys: &mut JoinKeySet,
) -> Result<LogicalPlan> {
    // Phase 1: look for the first right-hand plan that shares usable keys.
    let mut matched = None;
    for (position, candidate) in rights.iter().enumerate() {
        let mut keys = Vec::new();
        for (l, r) in possible_join_keys.iter() {
            let Some((left_key, right_key)) = find_valid_equijoin_key_pair(
                l,
                r,
                left_input.schema(),
                candidate.schema(),
            )?
            else {
                continue;
            };
            // Only keys whose type can be hashed are usable as join keys.
            if can_hash(&left_key.get_type(left_input.schema())?) {
                keys.push((left_key, right_key));
            }
        }
        if !keys.is_empty() {
            matched = Some((position, keys));
            break;
        }
    }

    // Phase 2: pick the right-hand side and remember the keys, if any.
    let (right_input, join_on) = match matched {
        Some((position, keys)) => {
            all_join_keys.insert_all(keys.iter());
            (rights.remove(position), Some(keys))
        }
        // No plan shares a key: fall back to the first one.
        None => (rights.remove(0), None),
    };

    let join_schema = Arc::new(build_join_schema(
        left_input.schema(),
        right_input.schema(),
        &JoinType::Inner,
    )?);
    let left = Arc::new(left_input);
    let right = Arc::new(right_input);

    Ok(match join_on {
        Some(on) => LogicalPlan::Join(Join {
            left,
            right,
            join_type: JoinType::Inner,
            join_constraint: JoinConstraint::On,
            on,
            filter: None,
            schema: join_schema,
            null_equals_null: false,
        }),
        None => LogicalPlan::CrossJoin(CrossJoin {
            left,
            right,
            schema: join_schema,
        }),
    })
}
/// Collects candidate equijoin keys from `expr` into `join_keys`.
///
/// * `a = b` contributes the pair `(a, b)`.
/// * `x AND y` contributes the candidates of both sides.
/// * `x OR y` contributes only the candidates found on *both* sides, since a
///   key is only guaranteed to hold if every branch of the `OR` requires it.
/// * Any other expression contributes nothing.
fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
    let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
        return;
    };

    match op {
        Operator::Eq => {
            join_keys.insert(left, right);
        }
        Operator::And => {
            extract_possible_join_keys(left, join_keys);
            extract_possible_join_keys(right, join_keys);
        }
        Operator::Or => {
            // Gather each branch separately, then keep the common keys only.
            let mut from_left = JoinKeySet::new();
            extract_possible_join_keys(left, &mut from_left);

            let mut from_right = JoinKeySet::new();
            extract_possible_join_keys(right, &mut from_right);

            join_keys.insert_intersection(from_left, from_right);
        }
        _ => {}
    }
}
/// Removes from `expr` the equality predicates that are contained in
/// `join_keys`, returning what is left of the expression.
///
/// `None` means nothing is left, i.e. the expression is implied by the join
/// keys:
///
/// * `a = b` where `(a, b)` is a join key: `None`.
/// * `x AND y`: both sides are reduced; if only one side survives it is
///   returned alone, if neither survives the result is `None`.
/// * `x OR y`: both sides are reduced; if either side reduces to `None` (i.e.
///   is always true given the join), the whole `OR` is `None`.
/// * Anything else is returned unchanged.
fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
    match expr {
        // An equality that became a join key is dropped.
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Eq,
            right,
        }) if join_keys.contains(&left, &right) => None,

        Expr::BinaryExpr(BinaryExpr {
            left,
            op: op @ (Operator::And | Operator::Or),
            right,
        }) => {
            let kept_left = remove_join_expressions(*left, join_keys);
            let kept_right = remove_join_expressions(*right, join_keys);
            match (kept_left, kept_right) {
                (Some(l), Some(r)) => Some(Expr::BinaryExpr(BinaryExpr::new(
                    Box::new(l),
                    op,
                    Box::new(r),
                ))),
                // For AND, a removed side simply disappears.
                (Some(survivor), None) | (None, Some(survivor))
                    if op == Operator::And =>
                {
                    Some(survivor)
                }
                // AND with both sides removed, or OR with any side removed.
                _ => None,
            }
        }

        untouched => Some(untouched),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::OptimizerContext;
    use crate::test::*;

    use datafusion_expr::{
        binary_expr, col, lit,
        logical_plan::builder::LogicalPlanBuilder,
        Operator::{And, Or},
    };

    /// Runs `EliminateCrossJoin` on `plan` and asserts that:
    ///
    /// * the rule reports the plan as transformed,
    /// * the indented plan display (with schemas) equals `expected`, line by
    ///   line, where each nesting level is indented by two spaces, and
    /// * the output schema is identical to the input schema.
    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) {
        let starting_schema = Arc::clone(plan.schema());
        let rule = EliminateCrossJoin::new();
        let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
        assert!(transformed_plan.transformed, "failed to optimize plan");
        let optimized_plan = transformed_plan.data;
        let formatted = optimized_plan.display_indent_schema().to_string();
        let actual: Vec<&str> = formatted.trim().lines().collect();

        assert_eq!(
            expected, actual,
            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
        );

        assert_eq!(&starting_schema, optimized_plan.schema())
    }

    /// Runs `EliminateCrossJoin` on `plan` and asserts that the rule reports
    /// the plan as *not* transformed.
    fn assert_optimization_rule_fails(plan: LogicalPlan) {
        let rule = EliminateCrossJoin::new();
        let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
        assert!(!transformed_plan.transformed)
    }

    #[test]
    fn eliminate_cross_with_simple_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join since filter has Join predicates
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                col("t1.a").eq(col("t2.a")),
                And,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_simple_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could not eliminate to inner join since the two OR branches do not
        // share a join predicate
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                col("t1.a").eq(col("t2.a")),
                Or,
                col("t2.b").eq(col("t1.a")),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
                And,
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join since both OR branches contain the
        // same join predicate
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_not_possible_simple() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could not eliminate to inner join: the OR branches use different
        // join predicates
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.b").eq(col("t2.b")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_not_possible() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could not eliminate to inner join: the second branch is itself an
        // OR, so the join predicate is not required by every branch
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        // could not eliminate to inner join with nested inner join with filter
        let plan = LogicalPlanBuilder::from(t1)
            .join(
                t3,
                JoinType::Inner,
                (vec!["t1.a"], vec!["t3.a"]),
                Some(col("t1.a").gt(lit(20u32))),
            )?
            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
            .filter(col("t1.a").gt(lit(15u32)))?
            .build()?;

        assert_optimization_rule_fails(plan);

        Ok(())
    }

    #[test]
    /// ```txt
    /// filter: a.id = b.id and a.id = c.id
    ///   cross_join a (bc)
    ///     cross_join b c
    /// ```
    /// Without reorder, it will be
    /// ```txt
    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
    ///     cross_join b c
    /// ```
    /// Reorder it to be
    /// ```txt
    ///   inner_join (ab)c and a.id = c.id
    ///     inner_join a b on a.id = b.id
    /// ```
    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        // could eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .cross_join(t3)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
                And,
                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ))?
            .build()?;

        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could eliminate to inner join
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ))?
            .build()?;

        // could not eliminate to inner join
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could not eliminate to inner join
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
            ))?
            .build()?;

        // could eliminate to inner join
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could not eliminate to inner join
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could eliminate to inner join
        let plan2 = LogicalPlanBuilder::from(t3)
            .cross_join(t4)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.c").eq(lit(688u32)),
                    ),
                ),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t4.a")),
                    And,
                    col("t3.b").eq(col("t4.b")),
                ),
            ))?
            .build()?;

        // could eliminate to inner join
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
                Or,
                binary_expr(
                    col("t3.a").eq(col("t1.a")),
                    And,
                    col("t4.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join
        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
        let plan1 = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
                And,
                binary_expr(
                    col("t1.a").eq(col("t2.a")),
                    And,
                    col("t2.c").eq(lit(688u32)),
                ),
            ))?
            .build()?;

        // could eliminate to inner join
        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

        // could eliminate to inner join
        // filter:
        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
        //     AND
        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        col("t3.a").eq(col("t1.a")),
                        And,
                        col("t4.c").lt(lit(15u32)),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t1.a")),
                        And,
                        col("t4.c").eq(lit(688u32)),
                    ),
                ),
                And,
                binary_expr(
                    binary_expr(
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t4.c").lt(lit(15u32)),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t3.c").eq(lit(688u32)),
                        ),
                    ),
                    Or,
                    binary_expr(
                        col("t3.a").eq(col("t4.a")),
                        And,
                        col("t3.b").eq(col("t4.b")),
                    ),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "      TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;
        let t4 = test_table_scan_with_name("t4")?;

        // could eliminate to inner join
        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;

        // could eliminate to inner join
        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

        // could eliminate to inner join
        // Filter:
        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
        //      AND
        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
        //      AND
        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
        let plan = LogicalPlanBuilder::from(plan1)
            .cross_join(plan2)?
            .filter(binary_expr(
                binary_expr(
                    binary_expr(
                        binary_expr(
                            col("t3.a").eq(col("t1.a")),
                            And,
                            col("t4.c").lt(lit(15u32)),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t1.a")),
                            And,
                            col("t4.c").eq(lit(688u32)),
                        ),
                    ),
                    And,
                    binary_expr(
                        binary_expr(
                            binary_expr(
                                col("t3.a").eq(col("t4.a")),
                                And,
                                col("t4.c").lt(lit(15u32)),
                            ),
                            Or,
                            binary_expr(
                                col("t3.a").eq(col("t4.a")),
                                And,
                                col("t3.c").eq(lit(688u32)),
                            ),
                        ),
                        Or,
                        binary_expr(
                            col("t3.a").eq(col("t4.a")),
                            And,
                            col("t3.b").eq(col("t4.b")),
                        ),
                    ),
                ),
                And,
                binary_expr(
                    binary_expr(
                        col("t1.a").eq(col("t2.a")),
                        Or,
                        col("t2.c").lt(lit(15u32)),
                    ),
                    And,
                    binary_expr(
                        col("t1.a").eq(col("t2.a")),
                        And,
                        col("t2.c").eq(lit(688u32)),
                    ),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
            "      TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_join_with_expr_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join since filter has Join predicates
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
                And,
                col("t2.c").lt(lit(20u32)),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_expr_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could not eliminate to inner join since the OR branches do not
        // share a join predicate
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
                Or,
                col("t2.b").eq(col("t1.a")),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_common_expr_and() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join
        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
                And,
                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn eliminate_cross_with_common_expr_or() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;

        // could eliminate to inner join since both OR branches contain the
        // same expression join predicate
        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .filter(binary_expr(
                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
                Or,
                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }

    #[test]
    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        let t3 = test_table_scan_with_name("t3")?;

        // could eliminate to inner join
        let plan = LogicalPlanBuilder::from(t1)
            .cross_join(t2)?
            .cross_join(t3)?
            .filter(binary_expr(
                binary_expr(
                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
                    And,
                    col("t3.c").lt(lit(15u32)),
                ),
                And,
                binary_expr(
                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
                    And,
                    col("t3.b").lt(lit(15u32)),
                ),
            ))?
            .build()?;

        let expected = vec![
            "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "    Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "      Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
        ];

        assert_optimized_plan_eq(plan, expected);

        Ok(())
    }
}