datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::Result;
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};
31
/// Optimizer rule that turns cross joins (inner joins without join keys) into
/// keyed inner joins when suitable equality predicates are available.
///
/// The type carries no state; the rewrite itself lives in its
/// [`OptimizerRule`] implementation.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92
93        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
94            // if input isn't a join that can potentially be rewritten
95            // avoid unwrapping the input
96            let rewritable = matches!(
97                filter.input.as_ref(),
98                LogicalPlan::Join(Join {
99                    join_type: JoinType::Inner,
100                    ..
101                })
102            );
103
104            if !rewritable {
105                // recursively try to rewrite children
106                return rewrite_children(self, LogicalPlan::Filter(filter), config);
107            }
108
109            if !can_flatten_join_inputs(&filter.input) {
110                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
111            }
112
113            let Filter {
114                input, predicate, ..
115            } = filter;
116            flatten_join_inputs(
117                Arc::unwrap_or_clone(input),
118                &mut possible_join_keys,
119                &mut all_inputs,
120                &mut all_filters,
121            )?;
122
123            extract_possible_join_keys(&predicate, &mut possible_join_keys);
124            Some(predicate)
125        } else if matches!(
126            plan,
127            LogicalPlan::Join(Join {
128                join_type: JoinType::Inner,
129                ..
130            })
131        ) {
132            if !can_flatten_join_inputs(&plan) {
133                return Ok(Transformed::no(plan));
134            }
135            flatten_join_inputs(
136                plan,
137                &mut possible_join_keys,
138                &mut all_inputs,
139                &mut all_filters,
140            )?;
141            None
142        } else {
143            // recursively try to rewrite children
144            return rewrite_children(self, plan, config);
145        };
146
147        // Join keys are handled locally:
148        let mut all_join_keys = JoinKeySet::new();
149        let mut left = all_inputs.remove(0);
150        while !all_inputs.is_empty() {
151            left = find_inner_join(
152                left,
153                &mut all_inputs,
154                &possible_join_keys,
155                &mut all_join_keys,
156            )?;
157        }
158
159        left = rewrite_children(self, left, config)?.data;
160
161        if &plan_schema != left.schema() {
162            left = LogicalPlan::Projection(Projection::new_from_schema(
163                Arc::new(left),
164                Arc::clone(&plan_schema),
165            ));
166        }
167
168        if !all_filters.is_empty() {
169            // Add any filters on top - PushDownFilter can push filters down to applicable join
170            let first = all_filters.swap_remove(0);
171            let predicate = all_filters.into_iter().fold(first, and);
172            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
173        }
174
175        let Some(predicate) = parent_predicate else {
176            return Ok(Transformed::yes(left));
177        };
178
179        // If there are no join keys then do nothing:
180        if all_join_keys.is_empty() {
181            Filter::try_new(predicate, Arc::new(left))
182                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
183        } else {
184            // Remove join expressions from filter:
185            match remove_join_expressions(predicate, &all_join_keys) {
186                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
187                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
188                _ => Ok(Transformed::yes(left)),
189            }
190        }
191    }
192
193    fn name(&self) -> &str {
194        "eliminate_cross_join"
195    }
196}
197
198fn rewrite_children(
199    optimizer: &impl OptimizerRule,
200    plan: LogicalPlan,
201    config: &dyn OptimizerConfig,
202) -> Result<Transformed<LogicalPlan>> {
203    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
204
205    // recompute schema if the plan was transformed
206    if transformed_plan.transformed {
207        transformed_plan.map_data(|plan| plan.recompute_schema())
208    } else {
209        Ok(transformed_plan)
210    }
211}
212
213/// Recursively accumulate possible_join_keys and inputs from inner joins
214/// (including cross joins).
215///
216/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
217/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
218/// possible_join_keys
219fn flatten_join_inputs(
220    plan: LogicalPlan,
221    possible_join_keys: &mut JoinKeySet,
222    all_inputs: &mut Vec<LogicalPlan>,
223    all_filters: &mut Vec<Expr>,
224) -> Result<()> {
225    match plan {
226        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
227            if let Some(filter) = join.filter {
228                all_filters.push(filter);
229            }
230            possible_join_keys.insert_all_owned(join.on);
231            flatten_join_inputs(
232                Arc::unwrap_or_clone(join.left),
233                possible_join_keys,
234                all_inputs,
235                all_filters,
236            )?;
237            flatten_join_inputs(
238                Arc::unwrap_or_clone(join.right),
239                possible_join_keys,
240                all_inputs,
241                all_filters,
242            )?;
243        }
244        _ => {
245            all_inputs.push(plan);
246        }
247    };
248    Ok(())
249}
250
251/// Returns true if the plan is a Join or Cross join could be flattened with
252/// `flatten_join_inputs`
253///
254/// Must stay in sync with `flatten_join_inputs`
255fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
256    // can only flatten inner / cross joins
257    match plan {
258        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
259        _ => return false,
260    };
261
262    for child in plan.inputs() {
263        if let LogicalPlan::Join(Join {
264            join_type: JoinType::Inner,
265            ..
266        }) = child
267        {
268            if !can_flatten_join_inputs(child) {
269                return false;
270            }
271        }
272    }
273    true
274}
275
276/// Finds the next to join with the left input plan,
277///
278/// Finds the next `right` from `rights` that can be joined with `left_input`
279/// plan based on the join keys in `possible_join_keys`.
280///
281/// If such a matching `right` is found:
282/// 1. Adds the matching join keys to `all_join_keys`.
283/// 2. Returns `left_input JOIN right ON (all join keys)`.
284///
285/// If no matching `right` is found:
286/// 1. Removes the first plan from `rights`
287/// 2. Returns `left_input CROSS JOIN right`.
288fn find_inner_join(
289    left_input: LogicalPlan,
290    rights: &mut Vec<LogicalPlan>,
291    possible_join_keys: &JoinKeySet,
292    all_join_keys: &mut JoinKeySet,
293) -> Result<LogicalPlan> {
294    for (i, right_input) in rights.iter().enumerate() {
295        let mut join_keys = vec![];
296
297        for (l, r) in possible_join_keys.iter() {
298            let key_pair = find_valid_equijoin_key_pair(
299                l,
300                r,
301                left_input.schema(),
302                right_input.schema(),
303            )?;
304
305            // Save join keys
306            if let Some((valid_l, valid_r)) = key_pair {
307                if can_hash(&valid_l.get_type(left_input.schema())?) {
308                    join_keys.push((valid_l, valid_r));
309                }
310            }
311        }
312
313        // Found one or more matching join keys
314        if !join_keys.is_empty() {
315            all_join_keys.insert_all(join_keys.iter());
316            let right_input = rights.remove(i);
317            let join_schema = Arc::new(build_join_schema(
318                left_input.schema(),
319                right_input.schema(),
320                &JoinType::Inner,
321            )?);
322
323            return Ok(LogicalPlan::Join(Join {
324                left: Arc::new(left_input),
325                right: Arc::new(right_input),
326                join_type: JoinType::Inner,
327                join_constraint: JoinConstraint::On,
328                on: join_keys,
329                filter: None,
330                schema: join_schema,
331                null_equals_null: false,
332            }));
333        }
334    }
335
336    // no matching right plan had any join keys, cross join with the first right
337    // plan
338    let right = rights.remove(0);
339    let join_schema = Arc::new(build_join_schema(
340        left_input.schema(),
341        right.schema(),
342        &JoinType::Inner,
343    )?);
344
345    Ok(LogicalPlan::Join(Join {
346        left: Arc::new(left_input),
347        right: Arc::new(right),
348        schema: join_schema,
349        on: vec![],
350        filter: None,
351        join_type: JoinType::Inner,
352        join_constraint: JoinConstraint::On,
353        null_equals_null: false,
354    }))
355}
356
357/// Extract join keys from a WHERE clause
358fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
359    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
360        match op {
361            Operator::Eq => {
362                // insert handles ensuring  we don't add the same Join keys multiple times
363                join_keys.insert(left, right);
364            }
365            Operator::And => {
366                extract_possible_join_keys(left, join_keys);
367                extract_possible_join_keys(right, join_keys)
368            }
369            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
370            Operator::Or => {
371                let mut left_join_keys = JoinKeySet::new();
372                let mut right_join_keys = JoinKeySet::new();
373
374                extract_possible_join_keys(left, &mut left_join_keys);
375                extract_possible_join_keys(right, &mut right_join_keys);
376
377                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
378            }
379            _ => (),
380        };
381    }
382}
383
384/// Remove join expressions from a filter expression
385///
386/// # Returns
387/// * `Some()` when there are few remaining predicates in filter_expr
388/// * `None` otherwise
389fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
390    match expr {
391        Expr::BinaryExpr(BinaryExpr {
392            left,
393            op: Operator::Eq,
394            right,
395        }) if join_keys.contains(&left, &right) => {
396            // was a join key, so remove it
397            None
398        }
399        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
400        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
401            let l = remove_join_expressions(*left, join_keys);
402            let r = remove_join_expressions(*right, join_keys);
403            match (l, r) {
404                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
405                    Box::new(ll),
406                    op,
407                    Box::new(rr),
408                ))),
409                (Some(ll), _) => Some(ll),
410                (_, Some(rr)) => Some(rr),
411                _ => None,
412            }
413        }
414        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
415            let l = remove_join_expressions(*left, join_keys);
416            let r = remove_join_expressions(*right, join_keys);
417            match (l, r) {
418                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
419                    Box::new(ll),
420                    op,
421                    Box::new(rr),
422                ))),
423                // When either `left` or `right` is empty, it means they are `true`
424                // so OR'ing anything with them will also be true
425                _ => None,
426            }
427        }
428        _ => Some(expr),
429    }
430}
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use crate::optimizer::OptimizerContext;
436    use crate::test::*;
437
438    use datafusion_expr::{
439        binary_expr, col, lit,
440        logical_plan::builder::LogicalPlanBuilder,
441        Operator::{And, Or},
442    };
443
444    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) {
445        let starting_schema = Arc::clone(plan.schema());
446        let rule = EliminateCrossJoin::new();
447        let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
448        assert!(transformed_plan.transformed, "failed to optimize plan");
449        let optimized_plan = transformed_plan.data;
450        let formatted = optimized_plan.display_indent_schema().to_string();
451        let actual: Vec<&str> = formatted.trim().lines().collect();
452
453        assert_eq!(
454            expected, actual,
455            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
456        );
457
458        assert_eq!(&starting_schema, optimized_plan.schema())
459    }
460
461    #[test]
462    fn eliminate_cross_with_simple_and() -> Result<()> {
463        let t1 = test_table_scan_with_name("t1")?;
464        let t2 = test_table_scan_with_name("t2")?;
465
466        // could eliminate to inner join since filter has Join predicates
467        let plan = LogicalPlanBuilder::from(t1)
468            .cross_join(t2)?
469            .filter(binary_expr(
470                col("t1.a").eq(col("t2.a")),
471                And,
472                col("t2.c").lt(lit(20u32)),
473            ))?
474            .build()?;
475
476        let expected = vec![
477            "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
478            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
479            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
480            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
481        ];
482
483        assert_optimized_plan_eq(plan, expected);
484
485        Ok(())
486    }
487
488    #[test]
489    fn eliminate_cross_with_simple_or() -> Result<()> {
490        let t1 = test_table_scan_with_name("t1")?;
491        let t2 = test_table_scan_with_name("t2")?;
492
493        // could not eliminate to inner join since filter OR expression and there is no common
494        // Join predicates in left and right of OR expr.
495        let plan = LogicalPlanBuilder::from(t1)
496            .cross_join(t2)?
497            .filter(binary_expr(
498                col("t1.a").eq(col("t2.a")),
499                Or,
500                col("t2.b").eq(col("t1.a")),
501            ))?
502            .build()?;
503
504        let expected = vec![
505            "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
506            "  Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
507            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
508            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
509        ];
510
511        assert_optimized_plan_eq(plan, expected);
512
513        Ok(())
514    }
515
516    #[test]
517    fn eliminate_cross_with_and() -> Result<()> {
518        let t1 = test_table_scan_with_name("t1")?;
519        let t2 = test_table_scan_with_name("t2")?;
520
521        // could eliminate to inner join
522        let plan = LogicalPlanBuilder::from(t1)
523            .cross_join(t2)?
524            .filter(binary_expr(
525                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
526                And,
527                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
528            ))?
529            .build()?;
530
531        let expected = vec![
532            "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
533            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
534            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
535            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
536        ];
537
538        assert_optimized_plan_eq(plan, expected);
539
540        Ok(())
541    }
542
543    #[test]
544    fn eliminate_cross_with_or() -> Result<()> {
545        let t1 = test_table_scan_with_name("t1")?;
546        let t2 = test_table_scan_with_name("t2")?;
547
548        // could eliminate to inner join since Or predicates have common Join predicates
549        let plan = LogicalPlanBuilder::from(t1)
550            .cross_join(t2)?
551            .filter(binary_expr(
552                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
553                Or,
554                binary_expr(
555                    col("t1.a").eq(col("t2.a")),
556                    And,
557                    col("t2.c").eq(lit(688u32)),
558                ),
559            ))?
560            .build()?;
561
562        let expected = vec![
563            "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
564            "  Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
565            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
566            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
567        ];
568        assert_optimized_plan_eq(plan, expected);
569
570        Ok(())
571    }
572
573    #[test]
574    fn eliminate_cross_not_possible_simple() -> Result<()> {
575        let t1 = test_table_scan_with_name("t1")?;
576        let t2 = test_table_scan_with_name("t2")?;
577
578        // could not eliminate to inner join
579        let plan = LogicalPlanBuilder::from(t1)
580            .cross_join(t2)?
581            .filter(binary_expr(
582                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
583                Or,
584                binary_expr(
585                    col("t1.b").eq(col("t2.b")),
586                    And,
587                    col("t2.c").eq(lit(688u32)),
588                ),
589            ))?
590            .build()?;
591
592        let expected = vec![
593            "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
594            "  Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
595            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
596            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
597        ];
598        assert_optimized_plan_eq(plan, expected);
599
600        Ok(())
601    }
602
603    #[test]
604    fn eliminate_cross_not_possible() -> Result<()> {
605        let t1 = test_table_scan_with_name("t1")?;
606        let t2 = test_table_scan_with_name("t2")?;
607
608        // could not eliminate to inner join
609        let plan = LogicalPlanBuilder::from(t1)
610            .cross_join(t2)?
611            .filter(binary_expr(
612                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
613                Or,
614                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
615            ))?
616            .build()?;
617
618        let expected = vec![
619            "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
620            "  Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
621            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
622            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
623        ];
624        assert_optimized_plan_eq(plan, expected);
625
626        Ok(())
627    }
628
629    #[test]
630    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
631        let t1 = test_table_scan_with_name("t1")?;
632        let t2 = test_table_scan_with_name("t2")?;
633        let t3 = test_table_scan_with_name("t3")?;
634
635        // could not eliminate to inner join with filter
636        let plan = LogicalPlanBuilder::from(t1)
637            .join(
638                t3,
639                JoinType::Inner,
640                (vec!["t1.a"], vec!["t3.a"]),
641                Some(col("t1.a").gt(lit(20u32))),
642            )?
643            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
644            .filter(col("t1.a").gt(lit(15u32)))?
645            .build()?;
646
647        let expected = vec![
648            "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
649            "  Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
650            "    Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
651            "      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
652            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
653            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
654            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"
655        ];
656
657        assert_optimized_plan_eq(plan, expected);
658
659        Ok(())
660    }
661
662    #[test]
663    /// ```txt
664    /// filter: a.id = b.id and a.id = c.id
665    ///   cross_join a (bc)
666    ///     cross_join b c
667    /// ```
668    /// Without reorder, it will be
669    /// ```txt
670    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
671    ///     cross_join b c
672    /// ```
673    /// Reorder it to be
674    /// ```txt
675    ///   inner_join (ab)c and a.id = c.id
676    ///     inner_join a b on a.id = b.id
677    /// ```
678    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
679        let t1 = test_table_scan_with_name("t1")?;
680        let t2 = test_table_scan_with_name("t2")?;
681        let t3 = test_table_scan_with_name("t3")?;
682
683        // could eliminate to inner join
684        let plan = LogicalPlanBuilder::from(t1)
685            .cross_join(t2)?
686            .cross_join(t3)?
687            .filter(binary_expr(
688                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
689                And,
690                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
691            ))?
692            .build()?;
693
694        let expected = vec![
695            "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
696            "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
697            "    Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
698            "      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
699            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
700            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
701            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
702        ];
703
704        assert_optimized_plan_eq(plan, expected);
705
706        Ok(())
707    }
708
709    #[test]
710    fn eliminate_cross_join_multi_tables() -> Result<()> {
711        let t1 = test_table_scan_with_name("t1")?;
712        let t2 = test_table_scan_with_name("t2")?;
713        let t3 = test_table_scan_with_name("t3")?;
714        let t4 = test_table_scan_with_name("t4")?;
715
716        // could eliminate to inner join
717        let plan1 = LogicalPlanBuilder::from(t1)
718            .cross_join(t2)?
719            .filter(binary_expr(
720                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
721                Or,
722                binary_expr(
723                    col("t1.a").eq(col("t2.a")),
724                    And,
725                    col("t2.c").eq(lit(688u32)),
726                ),
727            ))?
728            .build()?;
729
730        let plan2 = LogicalPlanBuilder::from(t3)
731            .cross_join(t4)?
732            .filter(binary_expr(
733                binary_expr(
734                    binary_expr(
735                        col("t3.a").eq(col("t4.a")),
736                        And,
737                        col("t4.c").lt(lit(15u32)),
738                    ),
739                    Or,
740                    binary_expr(
741                        col("t3.a").eq(col("t4.a")),
742                        And,
743                        col("t3.c").eq(lit(688u32)),
744                    ),
745                ),
746                Or,
747                binary_expr(
748                    col("t3.a").eq(col("t4.a")),
749                    And,
750                    col("t3.b").eq(col("t4.b")),
751                ),
752            ))?
753            .build()?;
754
755        let plan = LogicalPlanBuilder::from(plan1)
756            .cross_join(plan2)?
757            .filter(binary_expr(
758                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
759                Or,
760                binary_expr(
761                    col("t3.a").eq(col("t1.a")),
762                    And,
763                    col("t4.c").eq(lit(688u32)),
764                ),
765            ))?
766            .build()?;
767
768        let expected = vec![
769            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
770            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
771            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
772            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
773            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
774            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
775            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
776            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
777            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
778            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
779        ];
780
781        assert_optimized_plan_eq(plan, expected);
782
783        Ok(())
784    }
785
786    #[test]
787    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
788        let t1 = test_table_scan_with_name("t1")?;
789        let t2 = test_table_scan_with_name("t2")?;
790        let t3 = test_table_scan_with_name("t3")?;
791        let t4 = test_table_scan_with_name("t4")?;
792
793        // could eliminate to inner join
794        let plan1 = LogicalPlanBuilder::from(t1)
795            .cross_join(t2)?
796            .filter(binary_expr(
797                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
798                Or,
799                binary_expr(
800                    col("t1.a").eq(col("t2.a")),
801                    And,
802                    col("t2.c").eq(lit(688u32)),
803                ),
804            ))?
805            .build()?;
806
807        // could eliminate to inner join
808        let plan2 = LogicalPlanBuilder::from(t3)
809            .cross_join(t4)?
810            .filter(binary_expr(
811                binary_expr(
812                    binary_expr(
813                        col("t3.a").eq(col("t4.a")),
814                        And,
815                        col("t4.c").lt(lit(15u32)),
816                    ),
817                    Or,
818                    binary_expr(
819                        col("t3.a").eq(col("t4.a")),
820                        And,
821                        col("t3.c").eq(lit(688u32)),
822                    ),
823                ),
824                Or,
825                binary_expr(
826                    col("t3.a").eq(col("t4.a")),
827                    And,
828                    col("t3.b").eq(col("t4.b")),
829                ),
830            ))?
831            .build()?;
832
833        // could not eliminate to inner join
834        let plan = LogicalPlanBuilder::from(plan1)
835            .cross_join(plan2)?
836            .filter(binary_expr(
837                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
838                Or,
839                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
840            ))?
841            .build()?;
842
843        let expected = vec![
844            "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
845            "  Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
846            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
847            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
848            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
849            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
850            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
851            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
852            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
853            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
854        ];
855
856        assert_optimized_plan_eq(plan, expected);
857
858        Ok(())
859    }
860
861    #[test]
862    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
863        let t1 = test_table_scan_with_name("t1")?;
864        let t2 = test_table_scan_with_name("t2")?;
865        let t3 = test_table_scan_with_name("t3")?;
866        let t4 = test_table_scan_with_name("t4")?;
867
868        // could eliminate to inner join
869        let plan1 = LogicalPlanBuilder::from(t1)
870            .cross_join(t2)?
871            .filter(binary_expr(
872                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
873                Or,
874                binary_expr(
875                    col("t1.a").eq(col("t2.a")),
876                    And,
877                    col("t2.c").eq(lit(688u32)),
878                ),
879            ))?
880            .build()?;
881
882        // could not eliminate to inner join
883        let plan2 = LogicalPlanBuilder::from(t3)
884            .cross_join(t4)?
885            .filter(binary_expr(
886                binary_expr(
887                    binary_expr(
888                        col("t3.a").eq(col("t4.a")),
889                        And,
890                        col("t4.c").lt(lit(15u32)),
891                    ),
892                    Or,
893                    binary_expr(
894                        col("t3.a").eq(col("t4.a")),
895                        And,
896                        col("t3.c").eq(lit(688u32)),
897                    ),
898                ),
899                Or,
900                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
901            ))?
902            .build()?;
903
904        // could eliminate to inner join
905        let plan = LogicalPlanBuilder::from(plan1)
906            .cross_join(plan2)?
907            .filter(binary_expr(
908                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
909                Or,
910                binary_expr(
911                    col("t3.a").eq(col("t1.a")),
912                    And,
913                    col("t4.c").eq(lit(688u32)),
914                ),
915            ))?
916            .build()?;
917
918        let expected = vec![
919            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
920            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
921            "    Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
922            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
923            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
924            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
925            "    Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
926            "      Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
927            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
928            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
929        ];
930
931        assert_optimized_plan_eq(plan, expected);
932
933        Ok(())
934    }
935
936    #[test]
937    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
938        let t1 = test_table_scan_with_name("t1")?;
939        let t2 = test_table_scan_with_name("t2")?;
940        let t3 = test_table_scan_with_name("t3")?;
941        let t4 = test_table_scan_with_name("t4")?;
942
943        // could not eliminate to inner join
944        let plan1 = LogicalPlanBuilder::from(t1)
945            .cross_join(t2)?
946            .filter(binary_expr(
947                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
948                Or,
949                binary_expr(
950                    col("t1.a").eq(col("t2.a")),
951                    And,
952                    col("t2.c").eq(lit(688u32)),
953                ),
954            ))?
955            .build()?;
956
957        // could eliminate to inner join
958        let plan2 = LogicalPlanBuilder::from(t3)
959            .cross_join(t4)?
960            .filter(binary_expr(
961                binary_expr(
962                    binary_expr(
963                        col("t3.a").eq(col("t4.a")),
964                        And,
965                        col("t4.c").lt(lit(15u32)),
966                    ),
967                    Or,
968                    binary_expr(
969                        col("t3.a").eq(col("t4.a")),
970                        And,
971                        col("t3.c").eq(lit(688u32)),
972                    ),
973                ),
974                Or,
975                binary_expr(
976                    col("t3.a").eq(col("t4.a")),
977                    And,
978                    col("t3.b").eq(col("t4.b")),
979                ),
980            ))?
981            .build()?;
982
983        // could eliminate to inner join
984        let plan = LogicalPlanBuilder::from(plan1)
985            .cross_join(plan2)?
986            .filter(binary_expr(
987                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
988                Or,
989                binary_expr(
990                    col("t3.a").eq(col("t1.a")),
991                    And,
992                    col("t4.c").eq(lit(688u32)),
993                ),
994            ))?
995            .build()?;
996
997        let expected = vec![
998            "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
999            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1000            "    Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1001            "      Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1002            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1003            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1004            "    Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1005            "      Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1006            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
1007            "        TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
1008        ];
1009
1010        assert_optimized_plan_eq(plan, expected);
1011
1012        Ok(())
1013    }
1014
1015    #[test]
1016    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1017        let t1 = test_table_scan_with_name("t1")?;
1018        let t2 = test_table_scan_with_name("t2")?;
1019        let t3 = test_table_scan_with_name("t3")?;
1020        let t4 = test_table_scan_with_name("t4")?;
1021
1022        // could eliminate to inner join
1023        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
1024        let plan1 = LogicalPlanBuilder::from(t1)
1025            .cross_join(t2)?
1026            .filter(binary_expr(
1027                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1028                And,
1029                binary_expr(
1030                    col("t1.a").eq(col("t2.a")),
1031                    And,
1032                    col("t2.c").eq(lit(688u32)),
1033                ),
1034            ))?
1035            .build()?;
1036
1037        // could eliminate to inner join
1038        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1039
1040        // could eliminate to inner join
1041        // filter:
1042        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1043        //     AND
1044        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1045        let plan = LogicalPlanBuilder::from(plan1)
1046            .cross_join(plan2)?
1047            .filter(binary_expr(
1048                binary_expr(
1049                    binary_expr(
1050                        col("t3.a").eq(col("t1.a")),
1051                        And,
1052                        col("t4.c").lt(lit(15u32)),
1053                    ),
1054                    Or,
1055                    binary_expr(
1056                        col("t3.a").eq(col("t1.a")),
1057                        And,
1058                        col("t4.c").eq(lit(688u32)),
1059                    ),
1060                ),
1061                And,
1062                binary_expr(
1063                    binary_expr(
1064                        binary_expr(
1065                            col("t3.a").eq(col("t4.a")),
1066                            And,
1067                            col("t4.c").lt(lit(15u32)),
1068                        ),
1069                        Or,
1070                        binary_expr(
1071                            col("t3.a").eq(col("t4.a")),
1072                            And,
1073                            col("t3.c").eq(lit(688u32)),
1074                        ),
1075                    ),
1076                    Or,
1077                    binary_expr(
1078                        col("t3.a").eq(col("t4.a")),
1079                        And,
1080                        col("t3.b").eq(col("t4.b")),
1081                    ),
1082                ),
1083            ))?
1084            .build()?;
1085
1086        let expected = vec![
1087            "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1088            "  Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1089            "    Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1090            "      Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1091            "        Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1092            "          TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1093            "          TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1094            "      TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
1095            "    TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
1096        ];
1097
1098        assert_optimized_plan_eq(plan, expected);
1099
1100        Ok(())
1101    }
1102
1103    #[test]
1104    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1105        let t1 = test_table_scan_with_name("t1")?;
1106        let t2 = test_table_scan_with_name("t2")?;
1107        let t3 = test_table_scan_with_name("t3")?;
1108        let t4 = test_table_scan_with_name("t4")?;
1109
1110        // could eliminate to inner join
1111        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1112
1113        // could eliminate to inner join
1114        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1115
1116        // could eliminate to inner join
1117        // Filter:
1118        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1119        //      AND
1120        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1121        //      AND
1122        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
1123        let plan = LogicalPlanBuilder::from(plan1)
1124            .cross_join(plan2)?
1125            .filter(binary_expr(
1126                binary_expr(
1127                    binary_expr(
1128                        binary_expr(
1129                            col("t3.a").eq(col("t1.a")),
1130                            And,
1131                            col("t4.c").lt(lit(15u32)),
1132                        ),
1133                        Or,
1134                        binary_expr(
1135                            col("t3.a").eq(col("t1.a")),
1136                            And,
1137                            col("t4.c").eq(lit(688u32)),
1138                        ),
1139                    ),
1140                    And,
1141                    binary_expr(
1142                        binary_expr(
1143                            binary_expr(
1144                                col("t3.a").eq(col("t4.a")),
1145                                And,
1146                                col("t4.c").lt(lit(15u32)),
1147                            ),
1148                            Or,
1149                            binary_expr(
1150                                col("t3.a").eq(col("t4.a")),
1151                                And,
1152                                col("t3.c").eq(lit(688u32)),
1153                            ),
1154                        ),
1155                        Or,
1156                        binary_expr(
1157                            col("t3.a").eq(col("t4.a")),
1158                            And,
1159                            col("t3.b").eq(col("t4.b")),
1160                        ),
1161                    ),
1162                ),
1163                And,
1164                binary_expr(
1165                    binary_expr(
1166                        col("t1.a").eq(col("t2.a")),
1167                        Or,
1168                        col("t2.c").lt(lit(15u32)),
1169                    ),
1170                    And,
1171                    binary_expr(
1172                        col("t1.a").eq(col("t2.a")),
1173                        And,
1174                        col("t2.c").eq(lit(688u32)),
1175                    ),
1176                ),
1177            ))?
1178            .build()?;
1179
1180        let expected = vec![
1181            "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1182            "  Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1183            "    Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1184            "      Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1185            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1186            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1187            "      TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
1188            "    TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
1189        ];
1190
1191        assert_optimized_plan_eq(plan, expected);
1192
1193        Ok(())
1194    }
1195
1196    #[test]
1197    fn eliminate_cross_join_with_expr_and() -> Result<()> {
1198        let t1 = test_table_scan_with_name("t1")?;
1199        let t2 = test_table_scan_with_name("t2")?;
1200
1201        // could eliminate to inner join since filter has Join predicates
1202        let plan = LogicalPlanBuilder::from(t1)
1203            .cross_join(t2)?
1204            .filter(binary_expr(
1205                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1206                And,
1207                col("t2.c").lt(lit(20u32)),
1208            ))?
1209            .build()?;
1210
1211        let expected = vec![
1212            "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1213            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1214            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1215            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"];
1216
1217        assert_optimized_plan_eq(plan, expected);
1218
1219        Ok(())
1220    }
1221
1222    #[test]
1223    fn eliminate_cross_with_expr_or() -> Result<()> {
1224        let t1 = test_table_scan_with_name("t1")?;
1225        let t2 = test_table_scan_with_name("t2")?;
1226
1227        // could not eliminate to inner join since filter OR expression and there is no common
1228        // Join predicates in left and right of OR expr.
1229        let plan = LogicalPlanBuilder::from(t1)
1230            .cross_join(t2)?
1231            .filter(binary_expr(
1232                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1233                Or,
1234                col("t2.b").eq(col("t1.a")),
1235            ))?
1236            .build()?;
1237
1238        let expected = vec![
1239            "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1240            "  Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1241            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1242            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1243        ];
1244
1245        assert_optimized_plan_eq(plan, expected);
1246
1247        Ok(())
1248    }
1249
1250    #[test]
1251    fn eliminate_cross_with_common_expr_and() -> Result<()> {
1252        let t1 = test_table_scan_with_name("t1")?;
1253        let t2 = test_table_scan_with_name("t2")?;
1254
1255        // could eliminate to inner join
1256        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1257        let plan = LogicalPlanBuilder::from(t1)
1258            .cross_join(t2)?
1259            .filter(binary_expr(
1260                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1261                And,
1262                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1263            ))?
1264            .build()?;
1265
1266        let expected = vec![
1267            "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1268            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1269            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1270            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1271        ];
1272
1273        assert_optimized_plan_eq(plan, expected);
1274
1275        Ok(())
1276    }
1277
1278    #[test]
1279    fn eliminate_cross_with_common_expr_or() -> Result<()> {
1280        let t1 = test_table_scan_with_name("t1")?;
1281        let t2 = test_table_scan_with_name("t2")?;
1282
1283        // could eliminate to inner join since Or predicates have common Join predicates
1284        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1285        let plan = LogicalPlanBuilder::from(t1)
1286            .cross_join(t2)?
1287            .filter(binary_expr(
1288                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1289                Or,
1290                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1291            ))?
1292            .build()?;
1293
1294        let expected = vec![
1295            "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1296            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1297            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1298            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1299        ];
1300
1301        assert_optimized_plan_eq(plan, expected);
1302
1303        Ok(())
1304    }
1305
1306    #[test]
1307    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1308        let t1 = test_table_scan_with_name("t1")?;
1309        let t2 = test_table_scan_with_name("t2")?;
1310        let t3 = test_table_scan_with_name("t3")?;
1311
1312        // could eliminate to inner join
1313        let plan = LogicalPlanBuilder::from(t1)
1314            .cross_join(t2)?
1315            .cross_join(t3)?
1316            .filter(binary_expr(
1317                binary_expr(
1318                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1319                    And,
1320                    col("t3.c").lt(lit(15u32)),
1321                ),
1322                And,
1323                binary_expr(
1324                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1325                    And,
1326                    col("t3.b").lt(lit(15u32)),
1327                ),
1328            ))?
1329            .build()?;
1330
1331        let expected = vec![
1332            "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1333            "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1334            "    Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1335            "      Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
1336            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
1337            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
1338            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
1339        ];
1340
1341        assert_optimized_plan_eq(plan, expected);
1342
1343        Ok(())
1344    }
1345}