1use super::{add_offset_to_expr, ProjectionMapping};
19use crate::{
20 expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef,
21 PhysicalSortExpr, PhysicalSortRequirement,
22};
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24use datafusion_common::{JoinType, ScalarValue};
25use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
26use std::fmt::Display;
27use std::sync::Arc;
28use std::vec::IntoIter;
29
30use indexmap::{IndexMap, IndexSet};
31
32#[derive(Debug, Clone)]
71pub struct ConstExpr {
72 expr: Arc<dyn PhysicalExpr>,
74 across_partitions: AcrossPartitions,
77}
78
79#[derive(PartialEq, Clone, Debug)]
80pub enum AcrossPartitions {
89 Heterogeneous,
90 Uniform(Option<ScalarValue>),
91}
92
93impl Default for AcrossPartitions {
94 fn default() -> Self {
95 Self::Heterogeneous
96 }
97}
98
99impl PartialEq for ConstExpr {
100 fn eq(&self, other: &Self) -> bool {
101 self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
102 }
103}
104
105impl ConstExpr {
106 pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
111 Self {
112 expr,
113 across_partitions: Default::default(),
115 }
116 }
117
118 pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
122 self.across_partitions = across_partitions;
123 self
124 }
125
126 pub fn across_partitions(&self) -> AcrossPartitions {
130 self.across_partitions.clone()
131 }
132
133 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
134 &self.expr
135 }
136
137 pub fn owned_expr(self) -> Arc<dyn PhysicalExpr> {
138 self.expr
139 }
140
141 pub fn map<F>(&self, f: F) -> Option<Self>
142 where
143 F: Fn(&Arc<dyn PhysicalExpr>) -> Option<Arc<dyn PhysicalExpr>>,
144 {
145 let maybe_expr = f(&self.expr);
146 maybe_expr.map(|expr| Self {
147 expr,
148 across_partitions: self.across_partitions.clone(),
149 })
150 }
151
152 pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
154 self.expr.as_ref() == other.as_ref()
155 }
156
157 pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ {
159 struct DisplayableList<'a>(&'a [ConstExpr]);
160 impl Display for DisplayableList<'_> {
161 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
162 let mut first = true;
163 for const_expr in self.0 {
164 if first {
165 first = false;
166 } else {
167 write!(f, ",")?;
168 }
169 write!(f, "{}", const_expr)?;
170 }
171 Ok(())
172 }
173 }
174 DisplayableList(input)
175 }
176}
177
178impl Display for ConstExpr {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 write!(f, "{}", self.expr)?;
181 match &self.across_partitions {
182 AcrossPartitions::Heterogeneous => {
183 write!(f, "(heterogeneous)")?;
184 }
185 AcrossPartitions::Uniform(value) => {
186 if let Some(val) = value {
187 write!(f, "(uniform: {})", val)?;
188 } else {
189 write!(f, "(uniform: unknown)")?;
190 }
191 }
192 }
193 Ok(())
194 }
195}
196
197impl From<Arc<dyn PhysicalExpr>> for ConstExpr {
198 fn from(expr: Arc<dyn PhysicalExpr>) -> Self {
199 Self::new(expr)
200 }
201}
202
203impl From<&Arc<dyn PhysicalExpr>> for ConstExpr {
204 fn from(expr: &Arc<dyn PhysicalExpr>) -> Self {
205 Self::new(Arc::clone(expr))
206 }
207}
208
209pub fn const_exprs_contains(
211 const_exprs: &[ConstExpr],
212 expr: &Arc<dyn PhysicalExpr>,
213) -> bool {
214 const_exprs
215 .iter()
216 .any(|const_expr| const_expr.expr.eq(expr))
217}
218
219#[derive(Debug, Clone)]
227pub struct EquivalenceClass {
228 exprs: IndexSet<Arc<dyn PhysicalExpr>>,
232}
233
234impl PartialEq for EquivalenceClass {
235 fn eq(&self, other: &Self) -> bool {
238 self.exprs.eq(&other.exprs)
239 }
240}
241
242impl EquivalenceClass {
243 pub fn new_empty() -> Self {
245 Self {
246 exprs: IndexSet::new(),
247 }
248 }
249
250 pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
252 Self {
253 exprs: exprs.into_iter().collect(),
254 }
255 }
256
257 pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
259 self.exprs.into_iter().collect()
260 }
261
262 fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
265 self.exprs.iter().next().cloned()
266 }
267
268 pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
271 self.exprs.insert(expr);
272 }
273
274 pub fn extend(&mut self, other: Self) {
276 for expr in other.exprs {
277 self.push(expr);
279 }
280 }
281
282 pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
284 self.exprs.contains(expr)
285 }
286
287 pub fn contains_any(&self, other: &Self) -> bool {
289 self.exprs.iter().any(|e| other.contains(e))
290 }
291
292 pub fn len(&self) -> usize {
294 self.exprs.len()
295 }
296
297 pub fn is_empty(&self) -> bool {
299 self.exprs.is_empty()
300 }
301
302 pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn PhysicalExpr>> {
304 self.exprs.iter()
305 }
306
307 pub fn with_offset(&self, offset: usize) -> Self {
310 let new_exprs = self
311 .exprs
312 .iter()
313 .cloned()
314 .map(|e| add_offset_to_expr(e, offset))
315 .collect();
316 Self::new(new_exprs)
317 }
318}
319
320impl Display for EquivalenceClass {
321 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322 write!(f, "[{}]", format_physical_expr_list(&self.exprs))
323 }
324}
325
326#[derive(Debug, Clone)]
328pub struct EquivalenceGroup {
329 classes: Vec<EquivalenceClass>,
330}
331
332impl EquivalenceGroup {
333 pub fn empty() -> Self {
335 Self { classes: vec![] }
336 }
337
338 pub fn new(classes: Vec<EquivalenceClass>) -> Self {
340 let mut result = Self { classes };
341 result.remove_redundant_entries();
342 result
343 }
344
345 pub fn len(&self) -> usize {
347 self.classes.len()
348 }
349
350 pub fn is_empty(&self) -> bool {
352 self.len() == 0
353 }
354
355 pub fn iter(&self) -> impl Iterator<Item = &EquivalenceClass> {
357 self.classes.iter()
358 }
359
360 pub fn add_equal_conditions(
364 &mut self,
365 left: &Arc<dyn PhysicalExpr>,
366 right: &Arc<dyn PhysicalExpr>,
367 ) {
368 let mut first_class = None;
369 let mut second_class = None;
370 for (idx, cls) in self.classes.iter().enumerate() {
371 if cls.contains(left) {
372 first_class = Some(idx);
373 }
374 if cls.contains(right) {
375 second_class = Some(idx);
376 }
377 }
378 match (first_class, second_class) {
379 (Some(mut first_idx), Some(mut second_idx)) => {
380 if first_idx != second_idx {
383 if first_idx > second_idx {
385 (first_idx, second_idx) = (second_idx, first_idx);
386 }
387 let other_class = self.classes.swap_remove(second_idx);
391 self.classes[first_idx].extend(other_class);
392 }
393 }
394 (Some(group_idx), None) => {
395 self.classes[group_idx].push(Arc::clone(right));
397 }
398 (None, Some(group_idx)) => {
399 self.classes[group_idx].push(Arc::clone(left));
401 }
402 (None, None) => {
403 self.classes.push(EquivalenceClass::new(vec![
406 Arc::clone(left),
407 Arc::clone(right),
408 ]));
409 }
410 }
411 }
412
413 fn remove_redundant_entries(&mut self) {
415 self.classes.retain_mut(|cls| {
417 cls.len() > 1
420 });
421 self.bridge_classes()
423 }
424
425 fn bridge_classes(&mut self) {
430 let mut idx = 0;
431 while idx < self.classes.len() {
432 let mut next_idx = idx + 1;
433 let start_size = self.classes[idx].len();
434 while next_idx < self.classes.len() {
435 if self.classes[idx].contains_any(&self.classes[next_idx]) {
436 let extension = self.classes.swap_remove(next_idx);
437 self.classes[idx].extend(extension);
438 } else {
439 next_idx += 1;
440 }
441 }
442 if self.classes[idx].len() > start_size {
443 continue;
444 }
445 idx += 1;
446 }
447 }
448
449 pub fn extend(&mut self, other: Self) {
451 self.classes.extend(other.classes);
452 self.remove_redundant_entries();
453 }
454
455 pub fn normalize_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
459 expr.transform(|expr| {
460 for cls in self.iter() {
461 if cls.contains(&expr) {
462 return Ok(Transformed::yes(cls.canonical_expr().unwrap()));
465 }
466 }
467 Ok(Transformed::no(expr))
468 })
469 .data()
470 .unwrap()
471 }
473
474 pub fn normalize_sort_expr(
480 &self,
481 mut sort_expr: PhysicalSortExpr,
482 ) -> PhysicalSortExpr {
483 sort_expr.expr = self.normalize_expr(sort_expr.expr);
484 sort_expr
485 }
486
487 pub fn normalize_sort_requirement(
493 &self,
494 mut sort_requirement: PhysicalSortRequirement,
495 ) -> PhysicalSortRequirement {
496 sort_requirement.expr = self.normalize_expr(sort_requirement.expr);
497 sort_requirement
498 }
499
500 pub fn normalize_exprs(
503 &self,
504 exprs: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
505 ) -> Vec<Arc<dyn PhysicalExpr>> {
506 exprs
507 .into_iter()
508 .map(|expr| self.normalize_expr(expr))
509 .collect()
510 }
511
512 pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering {
516 let sort_reqs = LexRequirement::from(sort_exprs.clone());
518 let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
520 LexOrdering::from(normalized_sort_reqs)
522 }
523
524 pub fn normalize_sort_requirements(
528 &self,
529 sort_reqs: &LexRequirement,
530 ) -> LexRequirement {
531 LexRequirement::new(
532 sort_reqs
533 .iter()
534 .map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
535 .collect(),
536 )
537 .collapse()
538 }
539
540 pub fn project_expr(
543 &self,
544 mapping: &ProjectionMapping,
545 expr: &Arc<dyn PhysicalExpr>,
546 ) -> Option<Arc<dyn PhysicalExpr>> {
547 if let Some(target) = mapping.target_expr(expr) {
550 return Some(target);
552 } else {
553 for (source, target) in mapping.iter() {
556 if self
560 .get_equivalence_class(source)
561 .is_some_and(|group| group.contains(expr))
562 {
563 return Some(Arc::clone(target));
564 }
565 }
566 }
567 let children = expr.children();
569 if children.is_empty() {
570 return None;
572 }
573 children
574 .into_iter()
575 .map(|child| self.project_expr(mapping, child))
576 .collect::<Option<Vec<_>>>()
577 .map(|children| Arc::clone(expr).with_new_children(children).unwrap())
578 }
579
580 pub fn project(&self, mapping: &ProjectionMapping) -> Self {
582 let projected_classes = self.iter().filter_map(|cls| {
583 let new_class = cls
584 .iter()
585 .filter_map(|expr| self.project_expr(mapping, expr))
586 .collect::<Vec<_>>();
587 (new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
588 });
589
590 let mut new_classes: IndexMap<_, _> = IndexMap::new();
593 for (source, target) in mapping.iter() {
594 let normalized_expr = self.normalize_expr(Arc::clone(source));
600 new_classes
601 .entry(normalized_expr)
602 .or_insert_with(EquivalenceClass::new_empty)
603 .push(Arc::clone(target));
604 }
605 let new_classes = new_classes
608 .into_iter()
609 .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls));
610
611 let classes = projected_classes.chain(new_classes).collect();
612 Self::new(classes)
613 }
614
615 fn get_equivalence_class(
618 &self,
619 expr: &Arc<dyn PhysicalExpr>,
620 ) -> Option<&EquivalenceClass> {
621 self.iter().find(|cls| cls.contains(expr))
622 }
623
624 pub fn join(
626 &self,
627 right_equivalences: &Self,
628 join_type: &JoinType,
629 left_size: usize,
630 on: &[(PhysicalExprRef, PhysicalExprRef)],
631 ) -> Self {
632 match join_type {
633 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
634 let mut result = Self::new(
635 self.iter()
636 .cloned()
637 .chain(
638 right_equivalences
639 .iter()
640 .map(|cls| cls.with_offset(left_size)),
641 )
642 .collect(),
643 );
644 if join_type == &JoinType::Inner {
647 for (lhs, rhs) in on.iter() {
648 let new_lhs = Arc::clone(lhs);
649 let new_rhs = Arc::clone(rhs)
651 .transform(|expr| {
652 if let Some(column) =
653 expr.as_any().downcast_ref::<Column>()
654 {
655 let new_column = Arc::new(Column::new(
656 column.name(),
657 column.index() + left_size,
658 ))
659 as _;
660 return Ok(Transformed::yes(new_column));
661 }
662
663 Ok(Transformed::no(expr))
664 })
665 .data()
666 .unwrap();
667 result.add_equal_conditions(&new_lhs, &new_rhs);
668 }
669 }
670 result
671 }
672 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(),
673 JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
674 }
675 }
676
677 pub fn exprs_equal(
681 &self,
682 left: &Arc<dyn PhysicalExpr>,
683 right: &Arc<dyn PhysicalExpr>,
684 ) -> bool {
685 if left.eq(right) {
687 return true;
688 }
689
690 if let Some(left_class) = self.get_equivalence_class(left) {
693 if left_class.contains(right) {
694 return true;
695 }
696 }
697 if let Some(right_class) = self.get_equivalence_class(right) {
698 if right_class.contains(left) {
699 return true;
700 }
701 }
702
703 let left_children = left.children();
705 let right_children = right.children();
706
707 if left_children.is_empty() || right_children.is_empty() {
710 return false;
711 }
712
713 if left.as_any().type_id() != right.as_any().type_id() {
715 return false;
716 }
717
718 if left_children.len() != right_children.len() {
720 return false;
721 }
722
723 left_children
725 .into_iter()
726 .zip(right_children)
727 .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
728 }
729
730 pub fn into_inner(self) -> Vec<EquivalenceClass> {
732 self.classes
733 }
734}
735
736impl IntoIterator for EquivalenceGroup {
737 type Item = EquivalenceClass;
738 type IntoIter = IntoIter<EquivalenceClass>;
739
740 fn into_iter(self) -> Self::IntoIter {
741 self.classes.into_iter()
742 }
743}
744
745impl Display for EquivalenceGroup {
746 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
747 write!(f, "[")?;
748 let mut iter = self.iter();
749 if let Some(cls) = iter.next() {
750 write!(f, "{}", cls)?;
751 }
752 for cls in iter {
753 write!(f, ", {}", cls)?;
754 }
755 write!(f, "]")
756 }
757}
758
759#[cfg(test)]
760mod tests {
761 use super::*;
762 use crate::equivalence::tests::create_test_params;
763 use crate::expressions::{binary, col, lit, BinaryExpr, Literal};
764 use arrow::datatypes::{DataType, Field, Schema};
765
766 use datafusion_common::{Result, ScalarValue};
767 use datafusion_expr::Operator;
768
769 #[test]
770 fn test_bridge_groups() -> Result<()> {
771 let test_cases = vec![
773 (
775 vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]],
776 vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]],
778 ),
779 (
781 vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]],
782 vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]],
784 ),
785 ];
786 for (entries, expected) in test_cases {
787 let entries = entries
788 .into_iter()
789 .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
790 .map(EquivalenceClass::new)
791 .collect::<Vec<_>>();
792 let expected = expected
793 .into_iter()
794 .map(|entry| entry.into_iter().map(lit).collect::<Vec<_>>())
795 .map(EquivalenceClass::new)
796 .collect::<Vec<_>>();
797 let mut eq_groups = EquivalenceGroup::new(entries.clone());
798 eq_groups.bridge_classes();
799 let eq_groups = eq_groups.classes;
800 let err_msg = format!(
801 "error in test entries: {:?}, expected: {:?}, actual:{:?}",
802 entries, expected, eq_groups
803 );
804 assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg);
805 for idx in 0..eq_groups.len() {
806 assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg);
807 }
808 }
809 Ok(())
810 }
811
812 #[test]
813 fn test_remove_redundant_entries_eq_group() -> Result<()> {
814 let entries = [
815 EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
816 EquivalenceClass::new(vec![lit(3), lit(3)]),
818 EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
819 ];
820 let expected = [
823 EquivalenceClass::new(vec![lit(1), lit(2)]),
824 EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
825 ];
826 let mut eq_groups = EquivalenceGroup::new(entries.to_vec());
827 eq_groups.remove_redundant_entries();
828
829 let eq_groups = eq_groups.classes;
830 assert_eq!(eq_groups.len(), expected.len());
831 assert_eq!(eq_groups.len(), 2);
832
833 assert_eq!(eq_groups[0], expected[0]);
834 assert_eq!(eq_groups[1], expected[1]);
835 Ok(())
836 }
837
838 #[test]
839 fn test_schema_normalize_expr_with_equivalence() -> Result<()> {
840 let col_a = &Column::new("a", 0);
841 let col_b = &Column::new("b", 1);
842 let col_c = &Column::new("c", 2);
843 let (_test_schema, eq_properties) = create_test_params()?;
845
846 let col_a_expr = Arc::new(col_a.clone()) as Arc<dyn PhysicalExpr>;
847 let col_b_expr = Arc::new(col_b.clone()) as Arc<dyn PhysicalExpr>;
848 let col_c_expr = Arc::new(col_c.clone()) as Arc<dyn PhysicalExpr>;
849 let expressions = vec![
852 (&col_a_expr, &col_a_expr),
856 (&col_c_expr, &col_a_expr),
857 (&col_b_expr, &col_b_expr),
859 ];
860 let eq_group = eq_properties.eq_group();
861 for (expr, expected_eq) in expressions {
862 assert!(
863 expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))),
864 "error in test: expr: {expr:?}"
865 );
866 }
867
868 Ok(())
869 }
870
871 #[test]
872 fn test_contains_any() {
873 let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true))))
874 as Arc<dyn PhysicalExpr>;
875 let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false))))
876 as Arc<dyn PhysicalExpr>;
877 let lit2 =
878 Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
879 let lit1 =
880 Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
881 let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
882
883 let cls1 =
884 EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]);
885 let cls2 =
886 EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]);
887 let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]);
888
889 assert!(cls1.contains_any(&cls2));
891 assert!(!cls1.contains_any(&cls3));
893 assert!(!cls2.contains_any(&cls3));
894 }
895
896 #[test]
897 fn test_exprs_equal() -> Result<()> {
898 struct TestCase {
899 left: Arc<dyn PhysicalExpr>,
900 right: Arc<dyn PhysicalExpr>,
901 expected: bool,
902 description: &'static str,
903 }
904
905 let col_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
907 let col_b = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
908 let col_x = Arc::new(Column::new("x", 2)) as Arc<dyn PhysicalExpr>;
909 let col_y = Arc::new(Column::new("y", 3)) as Arc<dyn PhysicalExpr>;
910
911 let lit_1 =
913 Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
914 let lit_2 =
915 Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc<dyn PhysicalExpr>;
916
917 let eq_group = EquivalenceGroup::new(vec![
919 EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]),
920 EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]),
921 ]);
922
923 let test_cases = vec![
924 TestCase {
926 left: Arc::clone(&col_a),
927 right: Arc::clone(&col_a),
928 expected: true,
929 description: "Same column should be equal",
930 },
931 TestCase {
933 left: Arc::clone(&col_a),
934 right: Arc::clone(&col_x),
935 expected: true,
936 description: "Columns in same equivalence class should be equal",
937 },
938 TestCase {
939 left: Arc::clone(&col_b),
940 right: Arc::clone(&col_y),
941 expected: true,
942 description: "Columns in same equivalence class should be equal",
943 },
944 TestCase {
945 left: Arc::clone(&col_a),
946 right: Arc::clone(&col_b),
947 expected: false,
948 description:
949 "Columns in different equivalence classes should not be equal",
950 },
951 TestCase {
953 left: Arc::clone(&lit_1),
954 right: Arc::clone(&lit_1),
955 expected: true,
956 description: "Same literal should be equal",
957 },
958 TestCase {
959 left: Arc::clone(&lit_1),
960 right: Arc::clone(&lit_2),
961 expected: false,
962 description: "Different literals should not be equal",
963 },
964 TestCase {
966 left: Arc::new(BinaryExpr::new(
967 Arc::clone(&col_a),
968 Operator::Plus,
969 Arc::clone(&col_b),
970 )) as Arc<dyn PhysicalExpr>,
971 right: Arc::new(BinaryExpr::new(
972 Arc::clone(&col_x),
973 Operator::Plus,
974 Arc::clone(&col_y),
975 )) as Arc<dyn PhysicalExpr>,
976 expected: true,
977 description:
978 "Binary expressions with equivalent operands should be equal",
979 },
980 TestCase {
981 left: Arc::new(BinaryExpr::new(
982 Arc::clone(&col_a),
983 Operator::Plus,
984 Arc::clone(&col_b),
985 )) as Arc<dyn PhysicalExpr>,
986 right: Arc::new(BinaryExpr::new(
987 Arc::clone(&col_x),
988 Operator::Plus,
989 Arc::clone(&col_a),
990 )) as Arc<dyn PhysicalExpr>,
991 expected: false,
992 description:
993 "Binary expressions with non-equivalent operands should not be equal",
994 },
995 TestCase {
996 left: Arc::new(BinaryExpr::new(
997 Arc::clone(&col_a),
998 Operator::Plus,
999 Arc::clone(&lit_1),
1000 )) as Arc<dyn PhysicalExpr>,
1001 right: Arc::new(BinaryExpr::new(
1002 Arc::clone(&col_x),
1003 Operator::Plus,
1004 Arc::clone(&lit_1),
1005 )) as Arc<dyn PhysicalExpr>,
1006 expected: true,
1007 description: "Binary expressions with equivalent column and same literal should be equal",
1008 },
1009 TestCase {
1010 left: Arc::new(BinaryExpr::new(
1011 Arc::new(BinaryExpr::new(
1012 Arc::clone(&col_a),
1013 Operator::Plus,
1014 Arc::clone(&col_b),
1015 )),
1016 Operator::Multiply,
1017 Arc::clone(&lit_1),
1018 )) as Arc<dyn PhysicalExpr>,
1019 right: Arc::new(BinaryExpr::new(
1020 Arc::new(BinaryExpr::new(
1021 Arc::clone(&col_x),
1022 Operator::Plus,
1023 Arc::clone(&col_y),
1024 )),
1025 Operator::Multiply,
1026 Arc::clone(&lit_1),
1027 )) as Arc<dyn PhysicalExpr>,
1028 expected: true,
1029 description: "Nested binary expressions with equivalent operands should be equal",
1030 },
1031 ];
1032
1033 for TestCase {
1034 left,
1035 right,
1036 expected,
1037 description,
1038 } in test_cases
1039 {
1040 let actual = eq_group.exprs_equal(&left, &right);
1041 assert_eq!(
1042 actual, expected,
1043 "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
1044 description, left, right, expected, actual
1045 );
1046 }
1047
1048 Ok(())
1049 }
1050
1051 #[test]
1052 fn test_project_classes() -> Result<()> {
1053 let schema = Arc::new(Schema::new(vec![
1058 Field::new("a", DataType::Int32, false),
1059 Field::new("b", DataType::Int32, false),
1060 Field::new("c", DataType::Int32, false),
1061 ]));
1062 let mut group = EquivalenceGroup::empty();
1063 group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?);
1064
1065 let projected_schema = Arc::new(Schema::new(vec![
1066 Field::new("a+c", DataType::Int32, false),
1067 Field::new("b+c", DataType::Int32, false),
1068 ]));
1069
1070 let mapping = ProjectionMapping {
1071 map: vec![
1072 (
1073 binary(
1074 col("a", &schema)?,
1075 Operator::Plus,
1076 col("c", &schema)?,
1077 &schema,
1078 )?,
1079 col("a+c", &projected_schema)?,
1080 ),
1081 (
1082 binary(
1083 col("b", &schema)?,
1084 Operator::Plus,
1085 col("c", &schema)?,
1086 &schema,
1087 )?,
1088 col("b+c", &projected_schema)?,
1089 ),
1090 ],
1091 };
1092
1093 let projected = group.project(&mapping);
1094
1095 assert!(!projected.is_empty());
1096 let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1097 let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1098
1099 assert!(first_normalized.eq(&second_normalized));
1100
1101 Ok(())
1102 }
1103}