use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution_plan::CardinalityEffect;
use crate::hash_utils::create_hashes;
use crate::metrics::BaselineMetrics;
use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::utils::transpose;
use datafusion_common::HashMap;
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

type MaybeBatch = Option<Result<RecordBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

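/// Inner state of [`RepartitionExec`]: the per-output-partition channels and
/// memory reservations, plus handles to the background tasks pulling from the
/// input partitions.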
#[derive(Debug)]
struct RepartitionExecState {
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,

    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

impl RepartitionExecState {
    fn new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        metrics: ExecutionPlanMetricsSet,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Self {
        let num_input_partitions = input.output_partitioning().partition_count();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            let (txs, rxs) = channels(num_output_partitions);
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{}[{partition}]", name))
                    .register(context.memory_pool()),
            ));
            channels.insert(partition, (tx, rx, reservation));
        }

        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for i in 0..num_input_partitions {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, (tx, _rx, reservation))| {
                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                })
                .collect();

            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                Arc::clone(&input),
                i,
                txs.clone(),
                partitioning.clone(),
                r_metrics,
                Arc::clone(&context),
            ));

            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation))| (partition, tx))
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }

        Self {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        }
    }
}

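/// Lazily initialized state of a [`RepartitionExec`], created on the first
/// call to `execute` so no work is spawned until the plan actually runs.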
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;

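/// A utility that routes (for round-robin) or splits (for hash partitioning)
/// incoming [`RecordBatch`]es across output partitions according to a
/// [`Partitioning`] scheme, recording time spent in `timer`.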
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
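    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`].
    ///
    /// Time spent computing the partitioning is recorded to `timer`.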
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

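    /// Partition the provided [`RecordBatch`] into one or more partitioned
    /// [`RecordBatch`]es, invoking `f(partition_index, partitioned_batch)`
    /// once per non-empty output partition.
    ///
    /// A minimal usage sketch (not compiled; assumes a `batch` in scope and a
    /// throwaway timer):
    ///
    /// ```ignore
    /// let mut partitioner = BatchPartitioner::try_new(
    ///     Partitioning::RoundRobinBatch(2),
    ///     metrics::Time::new(),
    /// )?;
    /// partitioner.partition(batch, |partition, partitioned| {
    ///     println!("partition {partition}: {} rows", partitioned.num_rows());
    ///     Ok(())
    /// })?;
    /// ```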
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

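    /// Like [`Self::partition`], but returns an iterator over the partitioned
    /// batches rather than invoking a callback, so results can be sent
    /// downstream one at a time without buffering them all first.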
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    timer.done();

                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            let _timer = partitioner_timer.timer();

                            let columns = take_arrays(batch.columns(), &indices, None)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

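    /// Returns the number of output partitions this partitioner produces.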
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

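/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// Each output partition pulls batches from every input partition over
/// channels; with `preserve_order` set, ordered inputs are merged via a
/// streaming merge instead of interleaved.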
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// The input plan whose partitions are redistributed
    input: Arc<dyn ExecutionPlan>,
    /// Lazily initialized channel/task state, shared across output streams
    state: LazyState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to merge (rather than interleave) ordered input partitions
    preserve_order: bool,
    /// Cached plan properties (partitioning, ordering, etc.)
    cache: PlanProperties,
}

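/// Metrics recorded for each input partition of a [`RepartitionExec`].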
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time spent fetching batches from the input
    fetch_time: metrics::Time,
    /// Time spent computing the partitioning of each batch
    repartition_time: metrics::Time,
    /// Time spent sending batches to each output partition's channel
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let lazy_state = Arc::clone(&self.state);
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.preserve_order;
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            let input_captured = Arc::clone(&input);
            let metrics_captured = metrics.clone();
            let name_captured = name.clone();
            let context_captured = Arc::clone(&context);
            let state = lazy_state
                .get_or_init(|| async move {
                    Mutex::new(RepartitionExecState::new(
                        input_captured,
                        partitioning,
                        metrics_captured,
                        preserve_order,
                        name_captured,
                        context_captured,
                    ))
                })
                .await;

            let (mut rx, reservation, abort_helper) = {
                let mut state = state.lock();

                let (_tx, rx, reservation) = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (rx, reservation, Arc::clone(&state.abort_helper))
            };

            trace!(
                "Before returning stream in {}::execute for partition: {}",
                name,
                partition
            );

            if preserve_order {
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{}[Merge {partition}]", name))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs)
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }
}

impl RepartitionExec {
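    /// Create a new `RepartitionExec` that repartitions `input` according to
    /// `partitioning` (without order preservation; see
    /// [`Self::with_preserve_order`]).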
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
    }

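    /// Specify whether this repartitioning operation should preserve the
    /// order of rows from its input when producing output. Preserving order
    /// only takes effect if the input is ordered and has more than one
    /// partition.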
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order = self.input.output_ordering().is_some()
            && self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

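    /// Return the sort expressions used for merging, if order is preserved.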
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

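    /// Pulls data from the specified input plan partition, partitions each
    /// batch, and feeds it to the per-output-partition channels in
    /// `output_channels`. Returns when the input is exhausted or all output
    /// receivers have been dropped.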
    async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // execute the child operator
        let timer = metrics.fetch_time.timer();
        let mut stream = input.execute(partition, context)?;
        timer.done();

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // The receiving end hung up (e.g. an early shutdown
                        // such as a LIMIT); stop sending to this partition
                        reservation.lock().shrink(size);
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // Periodically yield so this (potentially long-running) task does
            // not starve other tasks on the same tokio worker thread
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

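    /// Waits for `input_task` (which runs [`Self::pull_from_input`]) to
    /// complete: on error, propagates the error to every output channel;
    /// on success, sends `None` to each channel to signal end of input.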
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        match input_task.join().await {
            // Error while joining the task itself
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error returned by the input task
            Ok(Err(e)) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

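/// Stream for one output partition that receives batches from all input
/// partitions over a single shared channel.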
struct RepartitionStream {
    /// Number of input partitions that will send batches to this stream
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches
    num_input_partitions_processed: usize,

    /// Schema of the output batches
    schema: SchemaRef,

    /// Channel receiving batches from the input partitions
    input: DistributionReceiver<MaybeBatch>,

    /// Handle keeping the background tasks alive while this stream exists
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation, shrunk as batches are consumed
    reservation: SharedMemoryReservation,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.recv().poll_unpin(cx) {
                Poll::Ready(Some(Some(v))) => {
                    if let Ok(batch) = &v {
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }

                    return Poll::Ready(Some(v));
                }
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;

                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // all input partitions have finished sending batches
                        return Poll::Ready(None);
                    } else {
                        // other input partitions still have data to send
                        continue;
                    }
                }
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

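/// Stream that receives batches from exactly one input partition; used as a
/// merge input when order is preserved.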
struct PerPartitionStream {
    /// Schema of the output batches
    schema: SchemaRef,

    /// Channel receiving batches from one input partition
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle keeping the background tasks alive while this stream exists
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation, shrunk as batches are consumed
    reservation: SharedMemoryReservation,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.recv().poll_unpin(cx) {
            Poll::Ready(Some(Some(v))) => {
                if let Ok(batch) = &v {
                    self.reservation
                        .lock()
                        .shrink(batch.get_array_memory_size());
                }
                Poll::Ready(Some(v))
            }
            Poll::Ready(Some(None)) => {
                // the input partition has finished sending batches
                Poll::Ready(None)
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::{
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
        {collect, expressions::col},
    };

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::cast::as_string_array;
    use datafusion_common::{arrow_datafusion_err, assert_batches_sorted_eq, exec_err};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;

    use tokio::task::JoinSet;

    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;

        let total_rows: usize = output_partitions
            .iter()
            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
            .sum();

        assert_eq!(8, output_partitions.len());
        assert_eq!(total_rows, 8 * 50 * 3);

        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }

    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn error_for_input_exec() {
        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        let output_stream = exec.execute(0, task_ctx).unwrap();

        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        let expected = vec![
            "+------------------+",
            "| my_awesome_field |",
            "+------------------+",
            "| foo              |",
            "| bar              |",
            "| frob             |",
            "| baz              |",
            "+------------------+",
        ];

        assert_batches_sorted_eq!(&expected, &expected_batches);

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_batches_sorted_eq!(&expected, &batches);
    }

    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        let input = Arc::new(make_barrier_exec());

        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        drop(output_stream0);

        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        let batches = crate::common::collect(output_stream1).await.unwrap();

        let expected = vec![
            "+------------------+",
            "| my_awesome_field |",
            "+------------------+",
            "| baz              |",
            "| frob             |",
            "| gaz              |",
            "| grob             |",
            "+------------------+",
        ];

        assert_batches_sorted_eq!(&expected, &batches);
    }

    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        assert_eq!(batches_without_drop, batches_with_drop);
    }

    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn oom() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err =
                arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into());
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
}

#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    macro_rules! assert_plan {
        ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => {
            let physical_plan = $PLAN;
            let formatted = crate::displayable(&physical_plan).indent(true).to_string();
            let actual: Vec<&str> = formatted.trim().lines().collect();

            let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES
                .iter().map(|s| *s).collect();

            assert_eq!(
                expected_plan_lines, actual,
                "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n"
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))
            .unwrap()
            .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
            "  DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let options = SortOptions::default();
        LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options,
        }])
    }

    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}