1use datafusion_common::{internal_err, Result};
19use datafusion_physical_expr_common::sort_expr::LexOrdering;
20use std::iter::Peekable;
21use std::sync::Arc;
22
23use crate::equivalence::class::AcrossPartitions;
24use crate::ConstExpr;
25
26use super::EquivalenceProperties;
27use crate::PhysicalSortExpr;
28use arrow::datatypes::SchemaRef;
29use std::slice::Iter;
30
31fn calculate_union_binary(
41 lhs: EquivalenceProperties,
42 mut rhs: EquivalenceProperties,
43) -> Result<EquivalenceProperties> {
44 if !rhs.schema.eq(&lhs.schema) {
46 rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?;
47 }
48
49 let constants = lhs
52 .constants()
53 .iter()
54 .filter_map(|lhs_const| {
55 rhs.constants()
57 .iter()
58 .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr()))
59 .map(|rhs_const| {
60 let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr()));
61
62 if let (
64 AcrossPartitions::Uniform(Some(lhs_val)),
65 AcrossPartitions::Uniform(Some(rhs_val)),
66 ) = (lhs_const.across_partitions(), rhs_const.across_partitions())
67 {
68 if lhs_val == rhs_val {
69 const_expr = const_expr.with_across_partitions(
70 AcrossPartitions::Uniform(Some(lhs_val)),
71 )
72 }
73 }
74 const_expr
75 })
76 })
77 .collect::<Vec<_>>();
78
79 let mut orderings = UnionEquivalentOrderingBuilder::new();
82 orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs);
83 orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs);
84 let orderings = orderings.build();
85
86 let mut eq_properties =
87 EquivalenceProperties::new(lhs.schema).with_constants(constants);
88
89 eq_properties.add_new_orderings(orderings);
90 Ok(eq_properties)
91}
92
93pub fn calculate_union(
98 eqps: Vec<EquivalenceProperties>,
99 schema: SchemaRef,
100) -> Result<EquivalenceProperties> {
101 let mut iter = eqps.into_iter();
104 let Some(mut acc) = iter.next() else {
105 return internal_err!(
106 "Cannot calculate EquivalenceProperties for a union with no inputs"
107 );
108 };
109
110 if !acc.schema.eq(&schema) {
112 acc = acc.with_new_schema(schema)?;
113 }
114 for props in iter {
116 acc = calculate_union_binary(acc, props)?;
117 }
118 Ok(acc)
119}
120
/// Outcome of `UnionEquivalentOrderingBuilder::try_add_ordering`.
#[derive(Debug)]
enum AddedOrdering {
    /// The ordering was handled. It was either empty, added unchanged, or
    /// added in one or more augmented forms.
    Yes,
    /// The ordering was not added. It is handed back so the caller can
    /// shorten it and retry.
    No(LexOrdering),
}
128
/// Accumulates the orderings that remain valid at the output of a union of
/// two inputs.
#[derive(Debug)]
struct UnionEquivalentOrderingBuilder {
    /// Orderings collected so far; returned unchanged by `build`.
    orderings: Vec<LexOrdering>,
}
134
135impl UnionEquivalentOrderingBuilder {
136 fn new() -> Self {
137 Self { orderings: vec![] }
138 }
139
140 fn add_satisfied_orderings(
153 &mut self,
154 orderings: impl IntoIterator<Item = LexOrdering>,
155 constants: &[ConstExpr],
156 properties: &EquivalenceProperties,
157 ) {
158 for mut ordering in orderings.into_iter() {
159 loop {
161 match self.try_add_ordering(ordering, constants, properties) {
162 AddedOrdering::Yes => break,
163 AddedOrdering::No(o) => {
164 ordering = o;
165 ordering.pop();
166 }
167 }
168 }
169 }
170 }
171
172 fn try_add_ordering(
182 &mut self,
183 ordering: LexOrdering,
184 constants: &[ConstExpr],
185 properties: &EquivalenceProperties,
186 ) -> AddedOrdering {
187 if ordering.is_empty() {
188 AddedOrdering::Yes
189 } else if properties.ordering_satisfy(ordering.as_ref()) {
190 self.orderings.push(ordering);
193 AddedOrdering::Yes
194 } else {
195 if self.try_find_augmented_ordering(&ordering, constants, properties) {
198 AddedOrdering::Yes
199 } else {
200 AddedOrdering::No(ordering)
201 }
202 }
203 }
204
205 fn try_find_augmented_ordering(
209 &mut self,
210 ordering: &LexOrdering,
211 constants: &[ConstExpr],
212 properties: &EquivalenceProperties,
213 ) -> bool {
214 if constants.is_empty() {
216 return false;
217 }
218 let start_num_orderings = self.orderings.len();
219
220 for existing_ordering in properties.oeq_class.iter() {
223 if let Some(augmented_ordering) = self.augment_ordering(
224 ordering,
225 constants,
226 existing_ordering,
227 &properties.constants,
228 ) {
229 if !augmented_ordering.is_empty() {
230 assert!(properties.ordering_satisfy(augmented_ordering.as_ref()));
231 self.orderings.push(augmented_ordering);
232 }
233 }
234 }
235
236 self.orderings.len() > start_num_orderings
237 }
238
239 fn augment_ordering(
244 &mut self,
245 ordering: &LexOrdering,
246 constants: &[ConstExpr],
247 existing_ordering: &LexOrdering,
248 existing_constants: &[ConstExpr],
249 ) -> Option<LexOrdering> {
250 let mut augmented_ordering = LexOrdering::default();
251 let mut sort_expr_iter = ordering.iter().peekable();
252 let mut existing_sort_expr_iter = existing_ordering.iter().peekable();
253
254 while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some()
256 {
257 if let Some(expr) =
260 advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter)
261 {
262 augmented_ordering.push(expr);
263 } else if let Some(expr) =
264 advance_if_matches_constant(&mut sort_expr_iter, existing_constants)
265 {
266 augmented_ordering.push(expr);
267 } else if let Some(expr) =
268 advance_if_matches_constant(&mut existing_sort_expr_iter, constants)
269 {
270 augmented_ordering.push(expr);
271 } else {
272 break;
274 }
275 }
276
277 Some(augmented_ordering)
278 }
279
280 fn build(self) -> Vec<LexOrdering> {
281 self.orderings
282 }
283}
284
285fn advance_if_match(
292 iter1: &mut Peekable<Iter<PhysicalSortExpr>>,
293 iter2: &mut Peekable<Iter<PhysicalSortExpr>>,
294) -> Option<PhysicalSortExpr> {
295 if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2))
296 {
297 iter1.next().unwrap();
298 iter2.next().cloned()
299 } else {
300 None
301 }
302}
303
304fn advance_if_matches_constant(
311 iter: &mut Peekable<Iter<PhysicalSortExpr>>,
312 constants: &[ConstExpr],
313) -> Option<PhysicalSortExpr> {
314 let expr = iter.peek()?;
315 let const_expr = constants.iter().find(|c| c.eq_expr(expr))?;
316 let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options);
317 iter.next();
318 Some(found_expr)
319}
320
#[cfg(test)]
mod tests {

    use super::*;
    use crate::equivalence::class::const_exprs_contains;
    use crate::equivalence::tests::{create_test_schema, parse_sort_expr};
    use crate::expressions::col;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;

    use itertools::Itertools;

    #[test]
    fn test_union_equivalence_properties_multi_children_1() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        let schema3 = append_fields(&schema, "2");
        UnionEquivalenceTest::new(&schema)
            // Child 1: [a, b, c]
            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)
            // Child 2: [a1, b1, c1] on a schema with "1"-suffixed field names
            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)
            // Child 3: [a2, b2] on a schema with "2"-suffixed field names
            .with_child_sort(vec![vec!["a2", "b2"]], &schema3)
            // Only the common prefix [a, b] is expected to survive
            .with_expected_sort(vec![vec!["a", "b"]])
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_multi_children_2() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        let schema3 = append_fields(&schema, "2");
        UnionEquivalenceTest::new(&schema)
            // Child 1: [a, b, c]
            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)
            // Child 2: [a1, b1, c1]
            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)
            // Child 3: [a2, b2, c2]
            .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)
            // All children agree on the full ordering
            .with_expected_sort(vec![vec!["a", "b", "c"]])
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_multi_children_3() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        let schema3 = append_fields(&schema, "2");
        UnionEquivalenceTest::new(&schema)
            // Child 1: [a, b]
            .with_child_sort(vec![vec!["a", "b"]], &schema)
            // Child 2: [a1, b1, c1]
            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)
            // Child 3: [a2, b2, c2]
            .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)
            // The shortest child ordering bounds the result
            .with_expected_sort(vec![vec!["a", "b"]])
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_multi_children_4() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        let schema3 = append_fields(&schema, "2");
        UnionEquivalenceTest::new(&schema)
            // Child 1: [a, b]
            .with_child_sort(vec![vec!["a", "b"]], &schema)
            // Child 2: [a1, b1]
            .with_child_sort(vec![vec!["a1", "b1"]], &schema2)
            // Child 3: [b2, c2]
            .with_child_sort(vec![vec!["b2", "c2"]], &schema3)
            // No common prefix, so no ordering is expected
            .with_expected_sort(vec![])
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_multi_children_5() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        UnionEquivalenceTest::new(&schema)
            // Child 1: orderings [a, b] and [c]
            .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema)
            // Child 2: orderings [a1, b1] and [c1]
            .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2)
            // Both independent orderings survive
            .with_expected_sort(vec![vec!["a", "b"], vec!["c"]])
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_common_constants() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC], const [b, c]
                vec![vec!["a"]],
                vec!["b", "c"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [b ASC], const [a, c]
                vec![vec!["b"]],
                vec!["a", "c"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Expected: orderings [a ASC] and [b ASC]; only c is constant
                // on both sides
                vec![vec!["a"], vec!["b"]],
                vec!["c"],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_prefix() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC], no constants
                vec![vec!["a"]],
                vec![],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a ASC, b ASC], no constants
                vec![vec!["a", "b"]],
                vec![],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Expected: the common prefix [a ASC]
                vec![vec!["a"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_asc_desc_mismatch() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC]
                vec![vec!["a"]],
                vec![],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a DESC]
                vec![vec!["a DESC"]],
                vec![],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Opposite directions: no ordering survives
                vec![],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_different_schemas() {
        let schema = create_test_schema().unwrap();
        let schema2 = append_fields(&schema, "1");
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC] on the output schema
                vec![vec!["a"]],
                vec![],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a1 ASC, b1 ASC] on the suffixed schema
                vec![vec!["a1", "b1"]],
                vec![],
                &schema2,
            )
            .with_expected_sort_and_const_exprs(
                // Expected [a ASC]. NOTE(review): this relies on the schema
                // rewrite mapping a1/b1 onto a/b of the output schema.
                vec![vec!["a"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_fill_gaps() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC, c ASC], const [b]
                vec![vec!["a", "c"]],
                vec!["b"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [b ASC, c ASC], const [a]
                vec![vec!["b", "c"]],
                vec!["a"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Each side's constant fills the gap in the other's ordering
                vec![vec!["a", "b", "c"], vec!["b", "a", "c"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_no_fill_gaps() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC, c ASC], const [d] (does not fill the gap)
                vec![vec!["a", "c"]],
                vec!["d"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [b ASC, c ASC], const [a]
                vec![vec!["b", "c"]],
                vec!["a"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Only the prefix [a ASC] survives
                vec![vec!["a"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_fill_some_gaps() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [c ASC], const [a, b]
                vec![vec!["c"]],
                vec!["a", "b"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a DESC, b ASC], no constants
                vec![vec!["a DESC", "b"]],
                vec![],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // The second child's ordering holds for the first because a
                // and b are constant there
                vec![vec!["a DESC", "b"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC, c ASC], const [b]
                vec![vec!["a", "c"]],
                vec!["b"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [b DESC, c ASC], const [a]
                vec![vec!["b DESC", "c"]],
                vec!["a"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Gap fills keep the sort direction of the filled column
                vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_gap_fill_symmetric() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC, b ASC, d ASC], const [c]
                vec![vec!["a", "b", "d"]],
                vec!["c"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a ASC, c ASC, d ASC], const [b]
                vec![vec!["a", "c", "d"]],
                vec!["b"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Both interleavings of b and c are valid
                vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]],
                vec![],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_gap_fill_and_common() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a DESC, d ASC], const [b, c]
                vec![vec!["a DESC", "d"]],
                vec!["b", "c"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a DESC, c ASC, d ASC], const [b]
                vec![vec!["a DESC", "c", "d"]],
                vec!["b"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // c fills the gap; b stays constant on both sides
                vec![vec!["a DESC", "c", "d"]],
                vec!["b"],
            )
            .run()
    }

    #[test]
    fn test_union_equivalence_properties_constants_middle_desc() {
        let schema = create_test_schema().unwrap();
        UnionEquivalenceTest::new(&schema)
            .with_child_sort_and_const_exprs(
                // First child: [a ASC, b DESC, d ASC], const [c]
                // (DESC in the middle of the ordering)
                vec![vec!["a", "b DESC", "d"]],
                vec!["c"],
                &schema,
            )
            .with_child_sort_and_const_exprs(
                // Second child: [a ASC, c ASC, d ASC], const [b]
                vec![vec!["a", "c", "d"]],
                vec!["b"],
                &schema,
            )
            .with_expected_sort_and_const_exprs(
                // Both interleavings, keeping b DESC
                vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]],
                vec![],
            )
            .run()
    }

    /// Builder-style harness: describe the union children and the expected
    /// output properties, then call `run`.
    #[derive(Debug)]
    struct UnionEquivalenceTest {
        /// Schema of the union output; expected properties are built on it
        output_schema: SchemaRef,
        /// Equivalence properties of each union child
        child_properties: Vec<EquivalenceProperties>,
        /// Expected output properties; must be set before `run`
        expected_properties: Option<EquivalenceProperties>,
    }

    impl UnionEquivalenceTest {
        fn new(output_schema: &SchemaRef) -> Self {
            Self {
                output_schema: Arc::clone(output_schema),
                child_properties: vec![],
                expected_properties: None,
            }
        }

        /// Adds a child with the given orderings (no constants) on `schema`.
        fn with_child_sort(
            mut self,
            orderings: Vec<Vec<&str>>,
            schema: &SchemaRef,
        ) -> Self {
            let properties = self.make_props(orderings, vec![], schema);
            self.child_properties.push(properties);
            self
        }

        /// Adds a child with the given orderings and constant columns on
        /// `schema`.
        fn with_child_sort_and_const_exprs(
            mut self,
            orderings: Vec<Vec<&str>>,
            constants: Vec<&str>,
            schema: &SchemaRef,
        ) -> Self {
            let properties = self.make_props(orderings, constants, schema);
            self.child_properties.push(properties);
            self
        }

        /// Sets the expected output orderings (no constants) on the output
        /// schema.
        fn with_expected_sort(mut self, orderings: Vec<Vec<&str>>) -> Self {
            let properties = self.make_props(orderings, vec![], &self.output_schema);
            self.expected_properties = Some(properties);
            self
        }

        /// Sets the expected output orderings and constant columns on the
        /// output schema.
        fn with_expected_sort_and_const_exprs(
            mut self,
            orderings: Vec<Vec<&str>>,
            constants: Vec<&str>,
        ) -> Self {
            let properties = self.make_props(orderings, constants, &self.output_schema);
            self.expected_properties = Some(properties);
            self
        }

        /// Runs `calculate_union` on every permutation of the children and
        /// checks each result against the expected properties, so the result
        /// must not depend on input order. Panics if no expectation was set.
        fn run(self) {
            let Self {
                output_schema,
                child_properties,
                expected_properties,
            } = self;

            let expected_properties =
                expected_properties.expect("expected_properties not set");

            // Try every input order
            for child_properties in child_properties
                .iter()
                .cloned()
                .permutations(child_properties.len())
            {
                println!("--- permutation ---");
                for c in &child_properties {
                    println!("{c}");
                }
                let actual_properties =
                    calculate_union(child_properties, Arc::clone(&output_schema))
                        .expect("failed to calculate union equivalence properties");
                Self::assert_eq_properties_same(
                    &actual_properties,
                    &expected_properties,
                    format!(
                        "expected: {expected_properties:?}\nactual: {actual_properties:?}"
                    ),
                );
            }
        }

        /// Asserts that `lhs` and `rhs` have the same constants (compared by
        /// expression) and the same orderings, ignoring order. Each `rhs` item
        /// must be present in `lhs`, and the counts must match.
        fn assert_eq_properties_same(
            lhs: &EquivalenceProperties,
            rhs: &EquivalenceProperties,
            err_msg: String,
        ) {
            // Constants
            let lhs_constants = lhs.constants();
            let rhs_constants = rhs.constants();
            for rhs_constant in rhs_constants {
                assert!(
                    const_exprs_contains(lhs_constants, rhs_constant.expr()),
                    "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
                );
            }
            assert_eq!(
                lhs_constants.len(),
                rhs_constants.len(),
                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
            );

            // Orderings
            let lhs_orderings = lhs.oeq_class();
            let rhs_orderings = rhs.oeq_class();
            for rhs_ordering in rhs_orderings.iter() {
                assert!(
                    lhs_orderings.contains(rhs_ordering),
                    "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
                );
            }
            assert_eq!(
                lhs_orderings.len(),
                rhs_orderings.len(),
                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
            );
        }

        /// Builds `EquivalenceProperties` on `schema`.
        ///
        /// Each ordering is a list of sort-expression strings parsed by
        /// `parse_sort_expr` (e.g. `"a"` or `"a DESC"`). `constants` lists the
        /// column names to mark as constant. `&self` is not used.
        fn make_props(
            &self,
            orderings: Vec<Vec<&str>>,
            constants: Vec<&str>,
            schema: &SchemaRef,
        ) -> EquivalenceProperties {
            let orderings = orderings
                .iter()
                .map(|ordering| {
                    ordering
                        .iter()
                        .map(|name| parse_sort_expr(name, schema))
                        .collect::<LexOrdering>()
                })
                .collect::<Vec<_>>();

            let constants = constants
                .iter()
                .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap()))
                .collect::<Vec<_>>();

            EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings)
                .with_constants(constants)
        }
    }

    /// Both inputs mark `a` as constant with the same uniform value (10). The
    /// union is expected to keep that value on its first constant.
    #[test]
    fn test_union_constant_value_preservation() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));

        let col_a = col("a", &schema)?;
        let literal_10 = ScalarValue::Int32(Some(10));

        // Input 1: a = 10 uniformly
        let const_expr1 = ConstExpr::new(Arc::clone(&col_a))
            .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone())));
        let input1 = EquivalenceProperties::new(Arc::clone(&schema))
            .with_constants(vec![const_expr1]);

        // Input 2: a = 10 uniformly
        let const_expr2 = ConstExpr::new(Arc::clone(&col_a))
            .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone())));
        let input2 = EquivalenceProperties::new(Arc::clone(&schema))
            .with_constants(vec![const_expr2]);

        let union_props = calculate_union(vec![input1, input2], schema)?;

        // The first constant of the union is `a` with the preserved value
        let const_a = &union_props.constants()[0];
        assert!(const_a.expr().eq(&col_a));
        assert_eq!(
            const_a.across_partitions(),
            AcrossPartitions::Uniform(Some(literal_10))
        );

        Ok(())
    }

    /// Returns a copy of `schema` with `text` appended to every field name.
    /// Data types and nullability are unchanged.
    fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef {
        Arc::new(Schema::new(
            schema
                .fields()
                .iter()
                .map(|field| {
                    Field::new(
                        // Append the suffix to the field name
                        format!("{}{}", field.name(), text),
                        field.data_type().clone(),
                        field.is_nullable(),
                    )
                })
                .collect::<Vec<_>>(),
        ))
    }
}