datafusion_physical_expr/equivalence/
class.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{add_offset_to_expr, ProjectionMapping};
19use crate::{
20    expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef,
21    PhysicalSortExpr, PhysicalSortRequirement,
22};
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_common::{JoinType, ScalarValue};
25use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
26use std::fmt::Display;
27use std::sync::Arc;
28use std::vec::IntoIter;
29
30use indexmap::{IndexMap, IndexSet};
31
32/// A structure representing a expression known to be constant in a physical execution plan.
33///
34/// The `ConstExpr` struct encapsulates an expression that is constant during the execution
35/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would
36/// be known constant
37///
38/// # Fields
39///
40/// - `expr`: Constant expression for a node in the physical plan.
41///
42/// - `across_partitions`: A boolean flag indicating whether the constant
43///   expression is the same across partitions. If set to `true`, the constant
44///   expression has same value for all partitions. If set to `false`, the
45///   constant expression may have different values for different partitions.
46///
47/// # Example
48///
49/// ```rust
50/// # use datafusion_physical_expr::ConstExpr;
51/// # use datafusion_physical_expr::expressions::lit;
52/// let col = lit(5);
53/// // Create a constant expression from a physical expression ref
54/// let const_expr = ConstExpr::from(&col);
55/// // create a constant expression from a physical expression
56/// let const_expr = ConstExpr::from(col);
57/// ```
58// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum:
59//
60// ```
61// enum PartitionValues {
62//     Uniform(Option<ScalarValue>),           // Same value across all partitions
63//     Heterogeneous(Vec<Option<ScalarValue>>) // Different values per partition
64// }
65// ```
66//
67// This would provide more flexible representation of partition values.
68// Note: This is a breaking change for the equivalence API and should be
69// addressed in a separate issue/PR.
70#[derive(Debug, Clone)]
71pub struct ConstExpr {
72    /// The  expression that is known to be constant (e.g. a `Column`)
73    expr: Arc<dyn PhysicalExpr>,
74    /// Does the constant have the same value across all partitions? See
75    /// struct docs for more details
76    across_partitions: AcrossPartitions,
77}
78
79#[derive(PartialEq, Clone, Debug)]
80/// Represents whether a constant expression's value is uniform or varies across partitions.
81///
82/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
83/// in a physical execution plan:
84///
85/// - `Heterogeneous`: The constant expression may have different values for different partitions.
86/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions,
87///   or is `None` if the value is not specified.
88pub enum AcrossPartitions {
89    Heterogeneous,
90    Uniform(Option<ScalarValue>),
91}
92
93impl Default for AcrossPartitions {
94    fn default() -> Self {
95        Self::Heterogeneous
96    }
97}
98
99impl PartialEq for ConstExpr {
100    fn eq(&self, other: &Self) -> bool {
101        self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
102    }
103}
104
105impl ConstExpr {
106    /// Create a new constant expression from a physical expression.
107    ///
108    /// Note you can also use `ConstExpr::from` to create a constant expression
109    /// from a reference as well
110    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
111        Self {
112            expr,
113            // By default, assume constant expressions are not same across partitions.
114            across_partitions: Default::default(),
115        }
116    }
117
118    /// Set the `across_partitions` flag
119    ///
120    /// See struct docs for more details
121    pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
122        self.across_partitions = across_partitions;
123        self
124    }
125
126    /// Is the  expression the same across all partitions?
127    ///
128    /// See struct docs for more details
129    pub fn across_partitions(&self) -> AcrossPartitions {
130        self.across_partitions.clone()
131    }
132
133    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
134        &self.expr
135    }
136
137    pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> {
138        self.expr
139    }
140
141    pub fn map<F>(&self, f: F) -> Option<Self>
142    where
143        F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
144    {
145        let maybe_expr = f(&self.expr);
146        maybe_expr.map(|expr| Self {
147            expr,
148            across_partitions: self.across_partitions.clone(),
149        })
150    }
151
152    /// Returns true if this constant expression is equal to the given expression
153    pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
154        self.expr.as_ref() == other.as_ref()
155    }
156
157    /// Returns a [`Display`]able list of `ConstExpr`.
158    pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
159        struct DisplayableList<'a>(&'a [ConstExpr]);
160        impl Display for DisplayableList<'_> {
161            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
162                let mut first = true;
163                for const_expr in self.0 {
164                    if first {
165                        first = false;
166                    } else {
167                        write!(f, ",")?;
168                    }
169                    write!(f, "{}", const_expr)?;
170                }
171                Ok(())
172            }
173        }
174        DisplayableList(input)
175    }
176}
177
178impl Display for ConstExpr {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        write!(f, "{}", self.expr)?;
181        match &self.across_partitions {
182            AcrossPartitions::Heterogeneous => {
183                write!(f, "(heterogeneous)")?;
184            }
185            AcrossPartitions::Uniform(value) => {
186                if let Some(val) = value {
187                    write!(f, "(uniform: {})", val)?;
188                } else {
189                    write!(f, "(uniform: unknown)")?;
190                }
191            }
192        }
193        Ok(())
194    }
195}
196
197impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
198    fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
199        Self::new(expr)
200    }
201}
202
203impl From<&Arc<dyn PhysicalExpr>> for ConstExpr {
204    fn from(expr: &Arc<dyn PhysicalExpr>) -> Self {
205        Self::new(Arc::clone(expr))
206    }
207}
208
209/// Checks whether `expr` is among in the `const_exprs`.
210pub fn const_exprs_contains(
211    const_exprs: &[ConstExpr],
212    expr: &Arc<dyn PhysicalExpr>,
213) -> bool {
214    const_exprs
215        .iter()
216        .any(|const_expr| const_expr.expr.eq(expr))
217}
218
219/// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known
220/// to have the same value for all tuples in a relation. These are generated by
221/// equality predicates (e.g. `a = b`), typically equi-join conditions and
222/// equality conditions in filters.
223///
224/// Two `EquivalenceClass`es are equal if they contains the same expressions in
225/// without any ordering.
226#[derive(Debug, Clone)]
227pub struct EquivalenceClass {
228    /// The expressions in this equivalence class. The order doesn't
229    /// matter for equivalence purposes
230    ///
231    exprs: IndexSet<Arc<dyn PhysicalExpr>>,
232}
233
234impl PartialEq for EquivalenceClass {
235    /// Returns true if other is equal in the sense
236    /// of bags (multi-sets), disregarding their orderings.
237    fn eq(&self, other: &Self) -> bool {
238        self.exprs.eq(&other.exprs)
239    }
240}
241
242impl EquivalenceClass {
243    /// Create a new empty equivalence class
244    pub fn new_empty() -> Self {
245        Self {
246            exprs: IndexSet::new(),
247        }
248    }
249
250    // Create a new equivalence class from a pre-existing `Vec`
251    pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
252        Self {
253            exprs: exprs.into_iter().collect(),
254        }
255    }
256
257    /// Return the inner vector of expressions
258    pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
259        self.exprs.into_iter().collect()
260    }
261
262    /// Return the "canonical" expression for this class (the first element)
263    /// if any
264    fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
265        self.exprs.iter().next().cloned()
266    }
267
268    /// Insert the expression into this class, meaning it is known to be equal to
269    /// all other expressions in this class
270    pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
271        self.exprs.insert(expr);
272    }
273
274    /// Inserts all the expressions from other into this class
275    pub fn extend(&mut self, other: Self) {
276        for expr in other.exprs {
277            // use push so entries are deduplicated
278            self.push(expr);
279        }
280    }
281
282    /// Returns true if this equivalence class contains t expression
283    pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
284        self.exprs.contains(expr)
285    }
286
287    /// Returns true if this equivalence class has any entries in common with `other`
288    pub fn contains_any(&self, other: &Self) -> bool {
289        self.exprs.iter().any(|e| other.contains(e))
290    }
291
292    /// return the number of items in this class
293    pub fn len(&self) -> usize {
294        self.exprs.len()
295    }
296
297    /// return true if this class is empty
298    pub fn is_empty(&self) -> bool {
299        self.exprs.is_empty()
300    }
301
302    /// Iterate over all elements in this class, in some arbitrary order
303    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
304        self.exprs.iter()
305    }
306
307    /// Return a new equivalence class that have the specified offset added to
308    /// each expression (used when schemas are appended such as in joins)
309    pub fn with_offset(&self, offset: usize) -> Self {
310        let new_exprs = self
311            .exprs
312            .iter()
313            .cloned()
314            .map(|e| add_offset_to_expr(e, offset))
315            .collect();
316        Self::new(new_exprs)
317    }
318}
319
320impl Display for EquivalenceClass {
321    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322        write!(f, "[{}]", format_physical_expr_list(&self.exprs))
323    }
324}
325
326/// A collection of distinct `EquivalenceClass`es
327#[derive(Debug, Clone)]
328pub struct EquivalenceGroup {
329    classes: Vec<EquivalenceClass>,
330}
331
332impl EquivalenceGroup {
333    /// Creates an empty equivalence group.
334    pub fn empty() -> Self {
335        Self { classes: vec![] }
336    }
337
338    /// Creates an equivalence group from the given equivalence classes.
339    pub fn new(classes: Vec<EquivalenceClass>) -> Self {
340        let mut result = Self { classes };
341        result.remove_redundant_entries();
342        result
343    }
344
345    /// Returns how many equivalence classes there are in this group.
346    pub fn len(&self) -> usize {
347        self.classes.len()
348    }
349
350    /// Checks whether this equivalence group is empty.
351    pub fn is_empty(&self) -> bool {
352        self.len() == 0
353    }
354
355    /// Returns an iterator over the equivalence classes in this group.
356    pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
357        self.classes.iter()
358    }
359
360    /// Adds the equality `left` = `right` to this equivalence group.
361    /// New equality conditions often arise after steps like `Filter(a = b)`,
362    /// `Alias(a, a as b)` etc.
363    pub fn add_equal_conditions(
364        &mut self,
365        left: &Arc<dyn PhysicalExpr>,
366        right: &Arc<dyn PhysicalExpr>,
367    ) {
368        let mut first_class = None;
369        let mut second_class = None;
370        for (idx, cls) in self.classes.iter().enumerate() {
371            if cls.contains(left) {
372                first_class = Some(idx);
373            }
374            if cls.contains(right) {
375                second_class = Some(idx);
376            }
377        }
378        match (first_class, second_class) {
379            (Some(mut first_idx), Some(mut second_idx)) => {
380                // If the given left and right sides belong to different classes,
381                // we should unify/bridge these classes.
382                if first_idx != second_idx {
383                    // By convention, make sure `second_idx` is larger than `first_idx`.
384                    if first_idx > second_idx {
385                        (first_idx, second_idx) = (second_idx, first_idx);
386                    }
387                    // Remove the class at `second_idx` and merge its values with
388                    // the class at `first_idx`. The convention above makes sure
389                    // that `first_idx` is still valid after removing `second_idx`.
390                    let other_class = self.classes.swap_remove(second_idx);
391                    self.classes[first_idx].extend(other_class);
392                }
393            }
394            (Some(group_idx), None) => {
395                // Right side is new, extend left side's class:
396                self.classes[group_idx].push(Arc::clone(right));
397            }
398            (None, Some(group_idx)) => {
399                // Left side is new, extend right side's class:
400                self.classes[group_idx].push(Arc::clone(left));
401            }
402            (None, None) => {
403                // None of the expressions is among existing classes.
404                // Create a new equivalence class and extend the group.
405                self.classes.push(EquivalenceClass::new(vec![
406                    Arc::clone(left),
407                    Arc::clone(right),
408                ]));
409            }
410        }
411    }
412
413    /// Removes redundant entries from this group.
414    fn remove_redundant_entries(&mut self) {
415        // Remove duplicate entries from each equivalence class:
416        self.classes.retain_mut(|cls| {
417            // Keep groups that have at least two entries as singleton class is
418            // meaningless (i.e. it contains no non-trivial information):
419            cls.len() > 1
420        });
421        // Unify/bridge groups that have common expressions:
422        self.bridge_classes()
423    }
424
425    /// This utility function unifies/bridges classes that have common expressions.
426    /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
427    /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
428    /// equal and belong to one class. This utility converts merges such classes.
429    fn bridge_classes(&mut self) {
430        let mut idx = 0;
431        while idx < self.classes.len() {
432            let mut next_idx = idx + 1;
433            let start_size = self.classes[idx].len();
434            while next_idx < self.classes.len() {
435                if self.classes[idx].contains_any(&self.classes[next_idx]) {
436                    let extension = self.classes.swap_remove(next_idx);
437                    self.classes[idx].extend(extension);
438                } else {
439                    next_idx += 1;
440                }
441            }
442            if self.classes[idx].len() > start_size {
443                continue;
444            }
445            idx += 1;
446        }
447    }
448
449    /// Extends this equivalence group with the `other` equivalence group.
450    pub fn extend(&mut self, other: Self) {
451        self.classes.extend(other.classes);
452        self.remove_redundant_entries();
453    }
454
455    /// Normalizes the given physical expression according to this group.
456    /// The expression is replaced with the first expression in the equivalence
457    /// class it matches with (if any).
458    pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
459        expr.transform(|expr| {
460            for cls in self.iter() {
461                if cls.contains(&expr) {
462                    // The unwrap below is safe because the guard above ensures
463                    // that the class is not empty.
464                    return Ok(Transformed::yes(cls.canonical_expr().unwrap()));
465                }
466            }
467            Ok(Transformed::no(expr))
468        })
469        .data()
470        .unwrap()
471        // The unwrap above is safe because the closure always returns `Ok`.
472    }
473
474    /// Normalizes the given sort expression according to this group.
475    /// The underlying physical expression is replaced with the first expression
476    /// in the equivalence class it matches with (if any). If the underlying
477    /// expression does not belong to any equivalence class in this group, returns
478    /// the sort expression as is.
479    pub fn normalize_sort_expr(
480        &self,
481        mut sort_expr: PhysicalSortExpr,
482    ) -> PhysicalSortExpr {
483        sort_expr.expr = self.normalize_expr(sort_expr.expr);
484        sort_expr
485    }
486
487    /// Normalizes the given sort requirement according to this group.
488    /// The underlying physical expression is replaced with the first expression
489    /// in the equivalence class it matches with (if any). If the underlying
490    /// expression does not belong to any equivalence class in this group, returns
491    /// the given sort requirement as is.
492    pub fn normalize_sort_requirement(
493        &self,
494        mut sort_requirement: PhysicalSortRequirement,
495    ) -> PhysicalSortRequirement {
496        sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
497        sort_requirement
498    }
499
500    /// This function applies the `normalize_expr` function for all expressions
501    /// in `exprs` and returns the corresponding normalized physical expressions.
502    pub fn normalize_exprs(
503        &self,
504        exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
505    ) -> Vec<Arc<dyn PhysicalExpr>> {
506        exprs
507            .into_iter()
508            .map(|expr| self.normalize_expr(expr))
509            .collect()
510    }
511
512    /// This function applies the `normalize_sort_expr` function for all sort
513    /// expressions in `sort_exprs` and returns the corresponding normalized
514    /// sort expressions.
515    pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering {
516        // Convert sort expressions to sort requirements:
517        let sort_reqs = LexRequirement::from(sort_exprs.clone());
518        // Normalize the requirements:
519        let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
520        // Convert sort requirements back to sort expressions:
521        LexOrdering::from(normalized_sort_reqs)
522    }
523
524    /// This function applies the `normalize_sort_requirement` function for all
525    /// requirements in `sort_reqs` and returns the corresponding normalized
526    /// sort requirements.
527    pub fn normalize_sort_requirements(
528        &self,
529        sort_reqs: &LexRequirement,
530    ) -> LexRequirement {
531        LexRequirement::new(
532            sort_reqs
533                .iter()
534                .map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
535                .collect(),
536        )
537        .collapse()
538    }
539
540    /// Projects `expr` according to the given projection mapping.
541    /// If the resulting expression is invalid after projection, returns `None`.
542    pub fn project_expr(
543        &self,
544        mapping: &ProjectionMapping,
545        expr: &Arc<dyn PhysicalExpr>,
546    ) -> Option<Arc<dyn PhysicalExpr>> {
547        // First, we try to project expressions with an exact match. If we are
548        // unable to do this, we consult equivalence classes.
549        if let Some(target) = mapping.target_expr(expr) {
550            // If we match the source, we can project directly:
551            return Some(target);
552        } else {
553            // If the given expression is not inside the mapping, try to project
554            // expressions considering the equivalence classes.
555            for (source, target) in mapping.iter() {
556                // If we match an equivalent expression to `source`, then we can
557                // project. For example, if we have the mapping `(a as a1, a + c)`
558                // and the equivalence class `(a, b)`, expression `b` projects to `a1`.
559                if self
560                    .get_equivalence_class(source)
561                    .is_some_and(|group| group.contains(expr))
562                {
563                    return Some(Arc::clone(target));
564                }
565            }
566        }
567        // Project a non-leaf expression by projecting its children.
568        let children = expr.children();
569        if children.is_empty() {
570            // Leaf expression should be inside mapping.
571            return None;
572        }
573        children
574            .into_iter()
575            .map(|child| self.project_expr(mapping, child))
576            .collect::<Option<Vec<_>>>()
577            .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
578    }
579
580    /// Projects this equivalence group according to the given projection mapping.
581    pub fn project(&self, mapping: &ProjectionMapping) -> Self {
582        let projected_classes = self.iter().filter_map(|cls| {
583            let new_class = cls
584                .iter()
585                .filter_map(|expr| self.project_expr(mapping, expr))
586                .collect::<Vec<_>>();
587            (new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
588        });
589
590        // The key is the source expression, and the value is the equivalence
591        // class that contains the corresponding target expression.
592        let mut new_classes: IndexMap<_, _> = IndexMap::new();
593        for (source, target) in mapping.iter() {
594            // We need to find equivalent projected expressions. For example,
595            // consider a table with columns `[a, b, c]` with `a` == `b`, and
596            // projection `[a + c, b + c]`. To conclude that `a + c == b + c`,
597            // we first normalize all source expressions in the mapping, then
598            // merge all equivalent expressions into the classes.
599            let normalized_expr = self.normalize_expr(Arc::clone(source));
600            new_classes
601                .entry(normalized_expr)
602                .or_insert_with(EquivalenceClass::new_empty)
603                .push(Arc::clone(target));
604        }
605        // Only add equivalence classes with at least two members as singleton
606        // equivalence classes are meaningless.
607        let new_classes = new_classes
608            .into_iter()
609            .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls));
610
611        let classes = projected_classes.chain(new_classes).collect();
612        Self::new(classes)
613    }
614
615    /// Returns the equivalence class containing `expr`. If no equivalence class
616    /// contains `expr`, returns `None`.
617    fn get_equivalence_class(
618        &self,
619        expr: &Arc<dyn PhysicalExpr>,
620    ) -> Option<&EquivalenceClass> {
621        self.iter().find(|cls| cls.contains(expr))
622    }
623
624    /// Combine equivalence groups of the given join children.
625    pub fn join(
626        &self,
627        right_equivalences: &Self,
628        join_type: &JoinType,
629        left_size: usize,
630        on: &[(PhysicalExprRef, PhysicalExprRef)],
631    ) -> Self {
632        match join_type {
633            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
634                let mut result = Self::new(
635                    self.iter()
636                        .cloned()
637                        .chain(
638                            right_equivalences
639                                .iter()
640                                .map(|cls| cls.with_offset(left_size)),
641                        )
642                        .collect(),
643                );
644                // In we have an inner join, expressions in the "on" condition
645                // are equal in the resulting table.
646                if join_type == &JoinType::Inner {
647                    for (lhs, rhs) in on.iter() {
648                        let new_lhs = Arc::clone(lhs);
649                        // Rewrite rhs to point to the right side of the join:
650                        let new_rhs = Arc::clone(rhs)
651                            .transform(|expr| {
652                                if let Some(column) =
653                                    expr.as_any().downcast_ref::<Column>()
654                                {
655                                    let new_column = Arc::new(Column::new(
656                                        column.name(),
657                                        column.index() + left_size,
658                                    ))
659                                        as _;
660                                    return Ok(Transformed::yes(new_column));
661                                }
662
663                                Ok(Transformed::no(expr))
664                            })
665                            .data()
666                            .unwrap();
667                        result.add_equal_conditions(&new_lhs, &new_rhs);
668                    }
669                }
670                result
671            }
672            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
673            JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
674        }
675    }
676
677    /// Checks if two expressions are equal either directly or through equivalence classes.
678    /// For complex expressions (e.g. a + b), checks that the expression trees are structurally
679    /// identical and their leaf nodes are equivalent either directly or through equivalence classes.
680    pub fn exprs_equal(
681        &self,
682        left: &Arc<dyn PhysicalExpr>,
683        right: &Arc<dyn PhysicalExpr>,
684    ) -> bool {
685        // Direct equality check
686        if left.eq(right) {
687            return true;
688        }
689
690        // Check if expressions are equivalent through equivalence classes
691        // We need to check both directions since expressions might be in different classes
692        if let Some(left_class) = self.get_equivalence_class(left) {
693            if left_class.contains(right) {
694                return true;
695            }
696        }
697        if let Some(right_class) = self.get_equivalence_class(right) {
698            if right_class.contains(left) {
699                return true;
700            }
701        }
702
703        // For non-leaf nodes, check structural equality
704        let left_children = left.children();
705        let right_children = right.children();
706
707        // If either expression is a leaf node and we haven't found equality yet,
708        // they must be different
709        if left_children.is_empty() || right_children.is_empty() {
710            return false;
711        }
712
713        // Type equality check through reflection
714        if left.as_any().type_id() != right.as_any().type_id() {
715            return false;
716        }
717
718        // Check if the number of children is the same
719        if left_children.len() != right_children.len() {
720            return false;
721        }
722
723        // Check if all children are equal
724        left_children
725            .into_iter()
726            .zip(right_children)
727            .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
728    }
729
730    /// Return the inner classes of this equivalence group.
731    pub fn into_inner(self) -> Vec<EquivalenceClass> {
732        self.classes
733    }
734}
735
736impl IntoIterator for EquivalenceGroup {
737    type Item = EquivalenceClass;
738    type IntoIter = IntoIter<EquivalenceClass>;
739
740    fn into_iter(self) -> Self::IntoIter {
741        self.classes.into_iter()
742    }
743}
744
745impl Display for EquivalenceGroup {
746    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
747        write!(f, "[")?;
748        let mut iter = self.iter();
749        if let Some(cls) = iter.next() {
750            write!(f, "{}", cls)?;
751        }
752        for cls in iter {
753            write!(f, ", {}", cls)?;
754        }
755        write!(f, "]")
756    }
757}
758
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::create_test_params;
    use crate::expressions::{binary, col, lit, BinaryExpr, Literal};
    use arrow::datatypes::{DataType, Field, Schema};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Operator;

    /// Checks that `bridge_classes` merges classes sharing at least one
    /// expression (directly or transitively) into a single class.
    #[test]
    fn test_bridge_groups() -> Result<()> {
        // First entry in the tuple is argument, second entry is the bridged result
        let test_cases = vec![
            // ------- TEST CASE 1 -----------//
            (
                vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
                // Classes are compared in order, but the expressions inside each
                // class are compared as a set: e.g. the input class `[11, 12, 9]`
                // matches the expected class `[9, 11, 12]`.
                vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
            ),
            // ------- TEST CASE 2 -----------//
            (
                vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
                // Expected
                vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
            ),
        ];
        for (entries, expected) in test_cases {
            let entries = entries
                .into_iter()
                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            let expected = expected
                .into_iter()
                .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
                .map(EquivalenceClass::new)
                .collect::<Vec<_>>();
            // `entries` is cloned so it can still be reported in the error message.
            let mut eq_groups = EquivalenceGroup::new(entries.clone());
            eq_groups.bridge_classes();
            let eq_groups = eq_groups.classes;
            let err_msg = format!(
                "error in test entries: {:?}, expected: {:?}, actual:{:?}",
                entries, expected, eq_groups
            );
            // `Vec` equality checks the length first and then compares each
            // class pairwise, in order.
            assert_eq!(eq_groups, expected, "{}", err_msg);
        }
        Ok(())
    }

    /// Checks that `remove_redundant_entries` collapses duplicate expressions
    /// and drops classes that end up with a single distinct expression.
    #[test]
    fn test_remove_redundant_entries_eq_group() -> Result<()> {
        let entries = [
            EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
            // This class is meaningless (it only equates `3` with itself) and
            // should be removed.
            EquivalenceClass::new(vec![lit(3), lit(3)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        // The given equivalence classes are not in succinct form.
        // The expected form is the plainest representation that is functionally the same.
        let expected = [
            EquivalenceClass::new(vec![lit(1), lit(2)]),
            EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
        ];
        let mut eq_groups = EquivalenceGroup::new(entries.to_vec());
        eq_groups.remove_redundant_entries();

        let eq_groups = eq_groups.classes;
        // Compares the length and every class, in order.
        assert_eq!(eq_groups, expected);
        Ok(())
    }

    /// Checks that expressions in an equivalence class normalize to the
    /// class's first entry, and that expressions outside any class are unchanged.
    #[test]
    fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
        // Assume that column a and c are aliases.
        let (_test_schema, eq_properties) = create_test_params()?;

        let col_a_expr = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
        let col_c_expr = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        // Test cases for equivalence normalization,
        // First entry in the tuple is argument, second entry is expected result after normalization.
        let expressions = vec![
            // Normalized version of the column a and c should go to a
            // (by convention all the expressions inside equivalence class are mapped to the first entry
            // in this case a is the first entry in the equivalence class.)
            (&col_a_expr, &col_a_expr),
            (&col_c_expr, &col_a_expr),
            // Cannot normalize column b
            (&col_b_expr, &col_b_expr),
        ];
        let eq_group = eq_properties.eq_group();
        for (expr, expected_eq) in expressions {
            assert!(
                expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))),
                "error in test: expr: {expr:?}"
            );
        }

        Ok(())
    }

    /// Checks that `contains_any` reports whether two classes share at least
    /// one expression.
    #[test]
    fn test_contains_any() {
        let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true))))
            as Arc<dyn PhysicalExpr>;
        let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false))))
            as Arc<dyn PhysicalExpr>;
        let lit2 =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
        let lit1 =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let cls1 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]);
        let cls2 =
            EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]);
        let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]);

        // lit_true is common
        assert!(cls1.contains_any(&cls2));
        // there is no common entry
        assert!(!cls1.contains_any(&cls3));
        assert!(!cls2.contains_any(&cls3));
    }

    /// Checks `exprs_equal` on columns, literals and (nested) binary expressions,
    /// where columns in the same equivalence class count as equal.
    #[test]
    fn test_exprs_equal() -> Result<()> {
        struct TestCase {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
            expected: bool,
            description: &'static str,
        }

        // Create test columns
        let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let col_b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
        let col_x = Arc::new(Column::new("x", 2)) as Arc<dyn PhysicalExpr>;
        let col_y = Arc::new(Column::new("y", 3)) as Arc<dyn PhysicalExpr>;

        // Create test literals
        let lit_1 =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let lit_2 =
            Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;

        // Create equivalence group with classes (a = x) and (b = y)
        let eq_group = EquivalenceGroup::new(vec![
            EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]),
            EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]),
        ]);

        let test_cases = vec![
            // Basic equality tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_a),
                expected: true,
                description: "Same column should be equal",
            },
            // Equivalence class tests
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_x),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_b),
                right: Arc::clone(&col_y),
                expected: true,
                description: "Columns in same equivalence class should be equal",
            },
            TestCase {
                left: Arc::clone(&col_a),
                right: Arc::clone(&col_b),
                expected: false,
                description:
                    "Columns in different equivalence classes should not be equal",
            },
            // Literal tests
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_1),
                expected: true,
                description: "Same literal should be equal",
            },
            TestCase {
                left: Arc::clone(&lit_1),
                right: Arc::clone(&lit_2),
                expected: false,
                description: "Different literals should not be equal",
            },
            // Complex expression tests
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as Arc<dyn PhysicalExpr>,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_y),
                )) as Arc<dyn PhysicalExpr>,
                expected: true,
                description:
                    "Binary expressions with equivalent operands should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&col_b),
                )) as Arc<dyn PhysicalExpr>,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&col_a),
                )) as Arc<dyn PhysicalExpr>,
                expected: false,
                description:
                    "Binary expressions with non-equivalent operands should not be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_a),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as Arc<dyn PhysicalExpr>,
                right: Arc::new(BinaryExpr::new(
                    Arc::clone(&col_x),
                    Operator::Plus,
                    Arc::clone(&lit_1),
                )) as Arc<dyn PhysicalExpr>,
                expected: true,
                description: "Binary expressions with equivalent column and same literal should be equal",
            },
            TestCase {
                left: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_a),
                        Operator::Plus,
                        Arc::clone(&col_b),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as Arc<dyn PhysicalExpr>,
                right: Arc::new(BinaryExpr::new(
                    Arc::new(BinaryExpr::new(
                        Arc::clone(&col_x),
                        Operator::Plus,
                        Arc::clone(&col_y),
                    )),
                    Operator::Multiply,
                    Arc::clone(&lit_1),
                )) as Arc<dyn PhysicalExpr>,
                expected: true,
                description: "Nested binary expressions with equivalent operands should be equal",
            },
        ];

        for TestCase {
            left,
            right,
            expected,
            description,
        } in test_cases
        {
            let actual = eq_group.exprs_equal(&left, &right);
            assert_eq!(
                actual, expected,
                "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
                description, left, right, expected, actual
            );
        }

        Ok(())
    }

    /// Checks that projecting an equivalence group keeps projected expressions
    /// equivalent when their source expressions were equivalent.
    #[test]
    fn test_project_classes() -> Result<()> {
        // - columns: [a, b, c].
        // - "a" and "b" are in the same equivalence class.
        // - then after the a+c, b+c projection, col(0) and col(1) must be
        //   in the same class too.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let mut group = EquivalenceGroup::empty();
        group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?);

        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a+c", DataType::Int32, false),
            Field::new("b+c", DataType::Int32, false),
        ]));

        let mapping = ProjectionMapping {
            map: vec![
                (
                    binary(
                        col("a", &schema)?,
                        Operator::Plus,
                        col("c", &schema)?,
                        &schema,
                    )?,
                    col("a+c", &projected_schema)?,
                ),
                (
                    binary(
                        col("b", &schema)?,
                        Operator::Plus,
                        col("c", &schema)?,
                        &schema,
                    )?,
                    col("b+c", &projected_schema)?,
                ),
            ],
        };

        let projected = group.project(&mapping);

        assert!(!projected.is_empty());
        let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
        let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);

        assert!(first_normalized.eq(&second_normalized));

        Ok(())
    }
}