1use std::borrow::Borrow;
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use std::{any::Any, sync::Arc};
28
29use super::{
30 metrics::{ExecutionPlanMetricsSet, MetricsSet},
31 ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
32 ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream,
33 SendableRecordBatchStream, Statistics,
34};
35use crate::execution_plan::{
36 boundedness_from_children, emission_type_from_children, InvariantLevel,
37};
38use crate::metrics::BaselineMetrics;
39use crate::projection::{make_with_child, ProjectionExec};
40use crate::stream::ObservedStream;
41
42use arrow::datatypes::{Field, Schema, SchemaRef};
43use arrow::record_batch::RecordBatch;
44use datafusion_common::stats::Precision;
45use datafusion_common::{exec_err, internal_err, DataFusionError, Result};
46use datafusion_execution::TaskContext;
47use datafusion_physical_expr::{calculate_union, EquivalenceProperties};
48
49use futures::Stream;
50use itertools::Itertools;
51use log::{debug, trace, warn};
52use tokio::macros::support::thread_rng_n;
53
/// Physical plan node that exposes the partitions of several input plans
/// side by side.
///
/// The number of output partitions is the sum of the inputs' partition
/// counts, and executing output partition `i` forwards to the matching
/// partition of exactly one input (see `compute_properties` and `execute`).
#[derive(Debug, Clone)]
pub struct UnionExec {
    /// Input execution plans, in the order their partitions are exposed.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics recorded by this node.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once at construction time and returned by
    /// `properties()`.
    cache: PlanProperties,
}
100
101impl UnionExec {
102 pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
104 let schema = union_schema(&inputs);
105 let cache = Self::compute_properties(&inputs, schema).unwrap();
111 UnionExec {
112 inputs,
113 metrics: ExecutionPlanMetricsSet::new(),
114 cache,
115 }
116 }
117
118 pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
120 &self.inputs
121 }
122
123 fn compute_properties(
125 inputs: &[Arc<dyn ExecutionPlan>],
126 schema: SchemaRef,
127 ) -> Result<PlanProperties> {
128 let children_eqps = inputs
130 .iter()
131 .map(|child| child.equivalence_properties().clone())
132 .collect::<Vec<_>>();
133 let eq_properties = calculate_union(children_eqps, schema)?;
134
135 let num_partitions = inputs
137 .iter()
138 .map(|plan| plan.output_partitioning().partition_count())
139 .sum();
140 let output_partitioning = Partitioning::UnknownPartitioning(num_partitions);
141 Ok(PlanProperties::new(
142 eq_properties,
143 output_partitioning,
144 emission_type_from_children(inputs),
145 boundedness_from_children(inputs),
146 ))
147 }
148}
149
150impl DisplayAs for UnionExec {
151 fn fmt_as(
152 &self,
153 t: DisplayFormatType,
154 f: &mut std::fmt::Formatter,
155 ) -> std::fmt::Result {
156 match t {
157 DisplayFormatType::Default | DisplayFormatType::Verbose => {
158 write!(f, "UnionExec")
159 }
160 }
161 }
162}
163
164impl ExecutionPlan for UnionExec {
165 fn name(&self) -> &'static str {
166 "UnionExec"
167 }
168
169 fn as_any(&self) -> &dyn Any {
171 self
172 }
173
174 fn properties(&self) -> &PlanProperties {
175 &self.cache
176 }
177
178 fn check_invariants(&self, _check: InvariantLevel) -> Result<()> {
179 (self.inputs().len() >= 2)
180 .then_some(())
181 .ok_or(DataFusionError::Internal(
182 "UnionExec should have at least 2 children".into(),
183 ))
184 }
185
186 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
187 self.inputs.iter().collect()
188 }
189
190 fn maintains_input_order(&self) -> Vec<bool> {
191 if let Some(output_ordering) = self.properties().output_ordering() {
200 self.inputs()
201 .iter()
202 .map(|child| {
203 if let Some(child_ordering) = child.output_ordering() {
204 output_ordering.len() == child_ordering.len()
205 } else {
206 false
207 }
208 })
209 .collect()
210 } else {
211 vec![false; self.inputs().len()]
212 }
213 }
214
215 fn with_new_children(
216 self: Arc<Self>,
217 children: Vec<Arc<dyn ExecutionPlan>>,
218 ) -> Result<Arc<dyn ExecutionPlan>> {
219 Ok(Arc::new(UnionExec::new(children)))
220 }
221
222 fn execute(
223 &self,
224 mut partition: usize,
225 context: Arc<TaskContext>,
226 ) -> Result<SendableRecordBatchStream> {
227 trace!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
228 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
229 let elapsed_compute = baseline_metrics.elapsed_compute().clone();
232 let _timer = elapsed_compute.timer(); for input in self.inputs.iter() {
236 if partition < input.output_partitioning().partition_count() {
238 let stream = input.execute(partition, context)?;
239 debug!("Found a Union partition to execute");
240 return Ok(Box::pin(ObservedStream::new(
241 stream,
242 baseline_metrics,
243 None,
244 )));
245 } else {
246 partition -= input.output_partitioning().partition_count();
247 }
248 }
249
250 warn!("Error in Union: Partition {} not found", partition);
251
252 exec_err!("Partition {partition} not found in Union")
253 }
254
255 fn metrics(&self) -> Option<MetricsSet> {
256 Some(self.metrics.clone_inner())
257 }
258
259 fn statistics(&self) -> Result<Statistics> {
260 let stats = self
261 .inputs
262 .iter()
263 .map(|stat| stat.statistics())
264 .collect::<Result<Vec<_>>>()?;
265
266 Ok(stats
267 .into_iter()
268 .reduce(stats_union)
269 .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
270 }
271
272 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
273 vec![false; self.children().len()]
274 }
275
276 fn supports_limit_pushdown(&self) -> bool {
277 true
278 }
279
280 fn try_swapping_with_projection(
284 &self,
285 projection: &ProjectionExec,
286 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
287 if projection.expr().len() >= projection.input().schema().fields().len() {
289 return Ok(None);
290 }
291
292 let new_children = self
293 .children()
294 .into_iter()
295 .map(|child| make_with_child(projection, child))
296 .collect::<Result<Vec<_>>>()?;
297
298 Ok(Some(Arc::new(UnionExec::new(new_children))))
299 }
300}
301
/// Physical plan node that merges the same-numbered partition of every input
/// into a single output partition.
///
/// It can only be constructed when `can_interleave` accepts the inputs, and
/// its output partitioning is copied from the first input (see
/// `compute_properties`).
#[derive(Debug, Clone)]
pub struct InterleaveExec {
    /// Input execution plans whose matching partitions are combined.
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics recorded by this node.
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties computed once at construction time and returned by
    /// `properties()`.
    cache: PlanProperties,
}
343
344impl InterleaveExec {
345 pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Result<Self> {
347 if !can_interleave(inputs.iter()) {
348 return internal_err!(
349 "Not all InterleaveExec children have a consistent hash partitioning"
350 );
351 }
352 let cache = Self::compute_properties(&inputs);
353 Ok(InterleaveExec {
354 inputs,
355 metrics: ExecutionPlanMetricsSet::new(),
356 cache,
357 })
358 }
359
360 pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
362 &self.inputs
363 }
364
365 fn compute_properties(inputs: &[Arc<dyn ExecutionPlan>]) -> PlanProperties {
367 let schema = union_schema(inputs);
368 let eq_properties = EquivalenceProperties::new(schema);
369 let output_partitioning = inputs[0].output_partitioning().clone();
371 PlanProperties::new(
372 eq_properties,
373 output_partitioning,
374 emission_type_from_children(inputs),
375 boundedness_from_children(inputs),
376 )
377 }
378}
379
380impl DisplayAs for InterleaveExec {
381 fn fmt_as(
382 &self,
383 t: DisplayFormatType,
384 f: &mut std::fmt::Formatter,
385 ) -> std::fmt::Result {
386 match t {
387 DisplayFormatType::Default | DisplayFormatType::Verbose => {
388 write!(f, "InterleaveExec")
389 }
390 }
391 }
392}
393
394impl ExecutionPlan for InterleaveExec {
395 fn name(&self) -> &'static str {
396 "InterleaveExec"
397 }
398
399 fn as_any(&self) -> &dyn Any {
401 self
402 }
403
404 fn properties(&self) -> &PlanProperties {
405 &self.cache
406 }
407
408 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
409 self.inputs.iter().collect()
410 }
411
412 fn maintains_input_order(&self) -> Vec<bool> {
413 vec![false; self.inputs().len()]
414 }
415
416 fn with_new_children(
417 self: Arc<Self>,
418 children: Vec<Arc<dyn ExecutionPlan>>,
419 ) -> Result<Arc<dyn ExecutionPlan>> {
420 if !can_interleave(children.iter()) {
422 return internal_err!(
423 "Can not create InterleaveExec: new children can not be interleaved"
424 );
425 }
426 Ok(Arc::new(InterleaveExec::try_new(children)?))
427 }
428
429 fn execute(
430 &self,
431 partition: usize,
432 context: Arc<TaskContext>,
433 ) -> Result<SendableRecordBatchStream> {
434 trace!("Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
435 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
436 let elapsed_compute = baseline_metrics.elapsed_compute().clone();
439 let _timer = elapsed_compute.timer(); let mut input_stream_vec = vec![];
442 for input in self.inputs.iter() {
443 if partition < input.output_partitioning().partition_count() {
444 input_stream_vec.push(input.execute(partition, Arc::clone(&context))?);
445 } else {
446 break;
448 }
449 }
450 if input_stream_vec.len() == self.inputs.len() {
451 let stream = Box::pin(CombinedRecordBatchStream::new(
452 self.schema(),
453 input_stream_vec,
454 ));
455 return Ok(Box::pin(ObservedStream::new(
456 stream,
457 baseline_metrics,
458 None,
459 )));
460 }
461
462 warn!("Error in InterleaveExec: Partition {} not found", partition);
463
464 exec_err!("Partition {partition} not found in InterleaveExec")
465 }
466
467 fn metrics(&self) -> Option<MetricsSet> {
468 Some(self.metrics.clone_inner())
469 }
470
471 fn statistics(&self) -> Result<Statistics> {
472 let stats = self
473 .inputs
474 .iter()
475 .map(|stat| stat.statistics())
476 .collect::<Result<Vec<_>>>()?;
477
478 Ok(stats
479 .into_iter()
480 .reduce(stats_union)
481 .unwrap_or_else(|| Statistics::new_unknown(&self.schema())))
482 }
483
484 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
485 vec![false; self.children().len()]
486 }
487}
488
489pub fn can_interleave<T: Borrow<Arc<dyn ExecutionPlan>>>(
496 mut inputs: impl Iterator<Item = T>,
497) -> bool {
498 let Some(first) = inputs.next() else {
499 return false;
500 };
501
502 let reference = first.borrow().output_partitioning();
503 matches!(reference, Partitioning::Hash(_, _))
504 && inputs
505 .map(|plan| plan.borrow().output_partitioning().clone())
506 .all(|partition| partition == *reference)
507}
508
509fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
510 let first_schema = inputs[0].schema();
511
512 let fields = (0..first_schema.fields().len())
513 .map(|i| {
514 inputs
515 .iter()
516 .enumerate()
517 .map(|(input_idx, input)| {
518 let field = input.schema().field(i).clone();
519 let mut metadata = field.metadata().clone();
520
521 let other_metadatas = inputs
522 .iter()
523 .enumerate()
524 .filter(|(other_idx, _)| *other_idx != input_idx)
525 .flat_map(|(_, other_input)| {
526 other_input.schema().field(i).metadata().clone().into_iter()
527 });
528
529 metadata.extend(other_metadatas);
530 field.with_metadata(metadata)
531 })
532 .find_or_first(Field::is_nullable)
533 .unwrap()
536 })
537 .collect::<Vec<_>>();
538
539 let all_metadata_merged = inputs
540 .iter()
541 .flat_map(|i| i.schema().metadata().clone().into_iter())
542 .collect();
543
544 Arc::new(Schema::new_with_metadata(fields, all_metadata_merged))
545}
546
/// Stream that yields record batches from several inner streams until all of
/// them are exhausted.
struct CombinedRecordBatchStream {
    /// Schema reported by this stream.
    schema: SchemaRef,
    /// Inner streams that have not finished yet; finished streams are removed
    /// while polling.
    entries: Vec<SendableRecordBatchStream>,
}
554
555impl CombinedRecordBatchStream {
556 pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
558 Self { schema, entries }
559 }
560}
561
562impl RecordBatchStream for CombinedRecordBatchStream {
563 fn schema(&self) -> SchemaRef {
564 Arc::clone(&self.schema)
565 }
566}
567
impl Stream for CombinedRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the inner streams in rotation and returns the first item that is
    /// ready.
    ///
    /// Streams that report completion are removed from `entries`. Returns
    /// `Ready(None)` once no inner stream is left, and `Pending` when streams
    /// remain but none produced an item during this pass.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        use Poll::*;

        // Starting index for this pass, chosen by tokio's `thread_rng_n`
        // (presumably pseudo-random so that no stream is always favoured --
        // confirm).
        let start = thread_rng_n(self.entries.len() as u32) as usize;
        let mut idx = start;

        // The number of iterations is fixed to the entry count at loop entry,
        // even though entries may be removed inside the loop.
        for _ in 0..self.entries.len() {
            let stream = self.entries.get_mut(idx).unwrap();

            match Pin::new(stream).poll_next(cx) {
                Ready(Some(val)) => return Ready(Some(val)),
                Ready(None) => {
                    // Drop the finished stream; `swap_remove` moves the last
                    // entry into the freed slot at `idx`.
                    self.entries.swap_remove(idx);

                    // Choose the next index to poll.
                    if idx == self.entries.len() {
                        // The removed entry was the last one: wrap around.
                        idx = 0;
                    } else if idx < start && start <= self.entries.len() {
                        // NOTE(review): steps past the entry that was just
                        // swapped into `idx`, presumably because it was
                        // already polled during this pass -- confirm.
                        idx = idx.wrapping_add(1) % self.entries.len();
                    }
                }
                Pending => {
                    // Not ready: move on to the next stream, wrapping around.
                    idx = idx.wrapping_add(1) % self.entries.len();
                }
            }
        }

        // With no streams left the combined stream is complete; otherwise at
        // least one stream is still pending.
        if self.entries.is_empty() {
            Ready(None)
        } else {
            Pending
        }
    }
}
613
614fn col_stats_union(
615 mut left: ColumnStatistics,
616 right: ColumnStatistics,
617) -> ColumnStatistics {
618 left.distinct_count = Precision::Absent;
619 left.min_value = left.min_value.min(&right.min_value);
620 left.max_value = left.max_value.max(&right.max_value);
621 left.sum_value = left.sum_value.add(&right.sum_value);
622 left.null_count = left.null_count.add(&right.null_count);
623
624 left
625}
626
627fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
628 left.num_rows = left.num_rows.add(&right.num_rows);
629 left.total_byte_size = left.total_byte_size.add(&right.total_byte_size);
630 left.column_statistics = left
631 .column_statistics
632 .into_iter()
633 .zip(right.column_statistics)
634 .map(|(a, b)| col_stats_union(a, b))
635 .collect::<Vec<_>>();
636 left
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collect;
    use crate::test;
    use crate::test::TestMemoryExec;

    use arrow::compute::SortOptions;
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
    use datafusion_physical_expr_common::sort_expr::LexOrdering;

    /// Creates a schema with seven nullable `Int32` columns named `a` to `g`.
    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let f = Field::new("f", DataType::Int32, true);
        let g = Field::new("g", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g]));

        Ok(schema)
    }

    /// Converts `(expression, options)` pairs into a `LexOrdering` of
    /// `PhysicalSortExpr`s, preserving their order.
    fn convert_to_sort_exprs(
        in_data: &[(&Arc<dyn PhysicalExpr>, SortOptions)],
    ) -> LexOrdering {
        in_data
            .iter()
            .map(|(expr, options)| PhysicalSortExpr {
                expr: Arc::clone(*expr),
                options: *options,
            })
            .collect::<LexOrdering>()
    }

    /// A union of a 4-partition and a 5-partition input exposes 9 partitions
    /// and collects 9 batches.
    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Two inputs with different partition counts.
        let csv = test::scan_partitioned(4);
        let csv2 = test::scan_partitioned(5);

        let union_exec = Arc::new(UnionExec::new(vec![csv, csv2]));

        // The partition count is the sum of the inputs' counts.
        assert_eq!(
            union_exec
                .properties()
                .output_partitioning()
                .partition_count(),
            9
        );

        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        assert_eq!(result.len(), 9);

        Ok(())
    }

    /// Checks `stats_union` on three columns: both sides exact, one side
    /// partially absent, and one side fully absent.
    #[tokio::test]
    async fn test_stats_union() {
        let left = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(5),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(0),
                },
                ColumnStatistics {
                    distinct_count: Precision::Exact(1),
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Exact(3),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
                    min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
                    sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))),
                    null_count: Precision::Absent,
                },
            ],
        };

        let right = Statistics {
            num_rows: Precision::Exact(7),
            total_byte_size: Precision::Exact(29),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Exact(3),
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(1))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
                    null_count: Precision::Exact(1),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("c")),
                    min_value: Precision::Exact(ScalarValue::from("b")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        let result = stats_union(left, right);
        // Expected: row counts and byte sizes add up, the distinct count is
        // always absent, and any absent operand makes the combined value
        // absent.
        let expected = Statistics {
            num_rows: Precision::Exact(12),
            total_byte_size: Precision::Exact(52),
            column_statistics: vec![
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::Int64(Some(34))),
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
                    sum_value: Precision::Exact(ScalarValue::Int64(Some(84))),
                    null_count: Precision::Exact(1),
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Exact(ScalarValue::from("x")),
                    min_value: Precision::Exact(ScalarValue::from("a")),
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
                ColumnStatistics {
                    distinct_count: Precision::Absent,
                    max_value: Precision::Absent,
                    min_value: Precision::Absent,
                    sum_value: Precision::Absent,
                    null_count: Precision::Absent,
                },
            ],
        };

        assert_eq!(result, expected);
    }

    /// Checks the orderings a `UnionExec` reports for two children with known
    /// sort information.
    #[tokio::test]
    async fn test_union_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let options = SortOptions::default();
        // Each case is (first child orderings, second child orderings,
        // expected union orderings).
        let test_cases = [
            // Case 1: the second child shares the ordering [a, b, f] with the
            // first child, so the union keeps it.
            (
                // First child orderings
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Second child orderings
                vec![
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
                // Expected union orderings
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                ],
            ),
            // Case 2: the children only agree on the prefix [a, b].
            (
                // First child orderings
                vec![
                    vec![(col_a, options), (col_b, options), (col_f, options)],
                    vec![(col_d, options)],
                ],
                // Second child orderings
                vec![
                    vec![(col_a, options), (col_b, options), (col_c, options)],
                    vec![(col_e, options)],
                ],
                // Expected union orderings
                vec![
                    vec![(col_a, options), (col_b, options)],
                ],
            ),
        ];

        for (
            test_idx,
            (first_child_orderings, second_child_orderings, union_orderings),
        ) in test_cases.iter().enumerate()
        {
            let first_orderings = first_child_orderings
                .iter()
                .map(|ordering| convert_to_sort_exprs(ordering))
                .collect::<Vec<_>>();
            let second_orderings = second_child_orderings
                .iter()
                .map(|ordering| convert_to_sort_exprs(ordering))
                .collect::<Vec<_>>();
            let union_expected_orderings = union_orderings
                .iter()
                .map(|ordering| convert_to_sort_exprs(ordering))
                .collect::<Vec<_>>();
            // Children carry no data; only their sort information matters.
            let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(first_orderings)?,
            )));
            let child2 = Arc::new(TestMemoryExec::update_cache(Arc::new(
                TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?
                    .try_with_sort_information(second_orderings)?,
            )));

            let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema));
            union_expected_eq.add_new_orderings(union_expected_orderings);

            let union = UnionExec::new(vec![child1, child2]);
            let union_eq_properties = union.properties().equivalence_properties();
            let err_msg = format!(
                "Error in test id: {:?}, test case: {:?}",
                test_idx, test_cases[test_idx]
            );
            assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg);
        }
        Ok(())
    }

    /// Asserts that `lhs` and `rhs` hold the same set of orderings: equal
    /// counts, and every ordering of `rhs` contained in `lhs`.
    fn assert_eq_properties_same(
        lhs: &EquivalenceProperties,
        rhs: &EquivalenceProperties,
        err_msg: String,
    ) {
        // Only the ordering equivalence classes are compared.
        let lhs_orderings = lhs.oeq_class();
        let rhs_orderings = rhs.oeq_class();
        assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg);
        for rhs_ordering in rhs_orderings.iter() {
            assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg);
        }
    }
}