// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution_plan::CardinalityEffect;
use crate::hash_utils::create_hashes;
use crate::metrics::BaselineMetrics;
use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::utils::transpose;
use datafusion_common::HashMap;
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

type MaybeBatch = Option<Result<RecordBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

/// Inner state of [`RepartitionExec`].
#[derive(Debug)]
struct RepartitionExecState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the partition number.
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,

    /// Helper that ensures the background job is killed once it is no longer needed.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

impl RepartitionExecState {
    fn new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        metrics: ExecutionPlanMetricsSet,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Self {
        let num_input_partitions = input.output_partitioning().partition_count();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Take the transpose of senders and receivers; `state.channels` keeps track of entries per output partition
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // create one channel per *output* partition
            // note we use a custom channel that ensures there is always data for each receiver
            // but limits the amount of buffering if required.
            let (txs, rxs) = channels(num_output_partitions);
            // Clone the sender for each input partition
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{}[{partition}]", name))
                    .register(context.memory_pool()),
            ));
            channels.insert(partition, (tx, rx, reservation));
        }

        // launch one async task per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for i in 0..num_input_partitions {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, (tx, _rx, reservation))| {
                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                })
                .collect();

            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                Arc::clone(&input),
                i,
                txs.clone(),
                partitioning.clone(),
                r_metrics,
                Arc::clone(&context),
            ));

            // In a separate task, wait for each input to be done
            // (and pass along any errors, including panic!s)
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation))| (partition, tx))
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }

        Self {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        }
    }
}

/// Lazily initialized state
///
/// Note that the state is initialized ONCE for all partitions by a single task (thread).
/// This may take a short while. It is also likely that multiple threads call `execute`
/// at the same time, because we have just started "target partitions" tasks, which is
/// commonly set to the number of CPU cores, and they all call `execute` at once.
///
/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
/// in a mutex lock but instead allow other threads to do something useful.
///
/// Uses a parking_lot `Mutex` to control other accesses, as they are of very short
/// duration (e.g. removing channels on completion) and the overhead of `await` is not
/// warranted.
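///
/// A minimal sketch of the resulting initialization pattern (the hypothetical
/// `make_state` stands in for the real constructor):
///
/// ```ignore
/// let state = lazy_state
///     // Only the first caller runs the future; concurrent callers await it.
///     .get_or_init(|| async move { Mutex::new(make_state()) })
///     .await;
/// ```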
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;

/// A utility that can be used to partition batches based on [`Partitioning`]
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`]
    ///
    /// The time spent repartitioning will be recorded to `timer`
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Use fixed seeds so the hash assignment is deterministic
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]es
    /// based on the [`Partitioning`] specified on construction
    ///
    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
    /// partition index. Any error returned by `f` will be immediately returned by this
    /// function without attempting to publish further [`RecordBatch`]es
    ///
    /// The time spent repartitioning, not including time spent in `f`, will be recorded
    /// to the [`metrics::Time`] provided on construction
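    ///
    /// # Example
    ///
    /// A minimal sketch of driving the partitioner (assumes `partitioning`,
    /// `timer`, and `batch` are already in scope; error handling elided):
    ///
    /// ```ignore
    /// let mut partitioner = BatchPartitioner::try_new(partitioning, timer)?;
    /// partitioner.partition(batch, |partition, partitioned| {
    ///     // Forward each partitioned batch to its output, e.g. a channel.
    ///     println!("partition {partition}: {} rows", partitioned.num_rows());
    ///     Ok(())
    /// })?;
    /// ```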
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

    /// Actual implementation of [`partition`](Self::partition).
    ///
    /// The reason this was pulled out is that we need a variant of `partition` that works
    /// with sync functions, and one that works with async. Using an iterator as an
    /// intermediate representation was the best way to achieve this (so we don't need to
    /// clone the entire implementation).
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    // Tracking time required for distributing indexes across output partitions
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    // Finished building index-arrays for output partitions
                    timer.done();

                    // Borrowing partitioner timer to prevent moving `self` to closure
                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Tracking time required for repartitioned batches construction
                            let _timer = partitioner_timer.timer();

                            // Produce batches based on indices
                            let columns = take_arrays(batch.columns(), &indices, None)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

    // Return the number of output partitions
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// # Background
///
/// DataFusion, like most other commercial systems (with the notable exception
/// of DuckDB), uses the "Exchange Operator" approach to parallelism, which
/// works well in practice given sufficient care in implementation.
///
/// DataFusion's planner picks the target number of partitions and
/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
/// of output partitions.
///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to get 3 even streams of `RecordBatch`es
///
///
///```text
///        ▲                  ▲                  ▲
///        │                  │                  │
///        │                  │                  │
///        │                  │                  │
///┌───────────────┐  ┌───────────────┐  ┌───────────────┐
///│    GroupBy    │  │    GroupBy    │  │    GroupBy    │
///│   (Partial)   │  │   (Partial)   │  │   (Partial)   │
///└───────────────┘  └───────────────┘  └───────────────┘
///        ▲                  ▲                  ▲
///        └──────────────────┼──────────────────┘
///                           │
///              ┌─────────────────────────┐
///              │     RepartitionExec     │
///              │   (hash/round robin)    │
///              └─────────────────────────┘
///                         ▲   ▲
///             ┌───────────┘   └───────────┐
///             │                           │
///             │                           │
///        .─────────.                 .─────────.
///     ,─'           '─.           ,─'           '─.
///    ;      Input      :         ;      Input      :
///    :   Partition 0   ;         :   Partition 1   ;
///     ╲               ╱           ╲               ╱
///      '─.         ,─'             '─.         ,─'
///         `───────'                   `───────'
///```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// all output partitions and inputs are not polled again.
///
/// # Output Ordering
///
/// If more than one stream is being repartitioned, the output will be some
/// arbitrary interleaving (and thus unordered) unless
/// [`Self::with_preserve_order`] specifies otherwise.
///
/// # Footnote
///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720),
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
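///
/// # Example
///
/// A minimal sketch of constructing a `RepartitionExec` (assumes `input` is an
/// existing `Arc<dyn ExecutionPlan>`; error handling elided):
///
/// ```ignore
/// // Cycle the input's batches across three output partitions,
/// // one batch at a time.
/// let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(3))?;
/// assert_eq!(exec.partitioning().partition_count(), 3);
/// ```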
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the first output stream is created.
    state: LazyState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Boolean flag to decide whether to preserve ordering. If true, this acts
    /// as a `SortPreservingRepartitionExec`; if false, as a plain `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning, etc.
    cache: PlanProperties,
}

#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition.
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Time in nanos to execute child operator and fetch batches
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        // Time in nanos to perform repartitioning
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        // Time in nanos for sending resulting batches to channels
        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Partitioning scheme to use
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Get the preserve_order flag of this RepartitionExec:
    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Get the name used to display this Exec
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let lazy_state = Arc::clone(&self.state);
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.preserve_order;
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Get existing ordering to use for merging
        let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            let input_captured = Arc::clone(&input);
            let metrics_captured = metrics.clone();
            let name_captured = name.clone();
            let context_captured = Arc::clone(&context);
            let state = lazy_state
                .get_or_init(|| async move {
                    Mutex::new(RepartitionExecState::new(
                        input_captured,
                        partitioning,
                        metrics_captured,
                        preserve_order,
                        name_captured,
                        context_captured,
                    ))
                })
                .await;

            // lock scope
            let (mut rx, reservation, abort_helper) = {
                // lock the mutex
                let mut state = state.lock();

                // now return stream for the specified *output* partition which will
                // read from the channel
                let (_tx, rx, reservation) = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (rx, reservation, Arc::clone(&state.abort_helper))
            };

            trace!(
                "Before returning stream in {}::execute for partition: {}",
                name,
                partition
            );

            if preserve_order {
                // Store streams from all the input partitions:
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                // Note that the receiver count (`rx.len()`) and `num_input_partitions` are the same.

                // Merge streams (while preserving ordering) coming from
                // input partitions to this partition:
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{}[Merge {partition}]", name))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs)
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        // If pushdown is not beneficial or applicable, break it.
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }
}

impl RepartitionExec {
    /// Create a new RepartitionExec that produces the output `partitioning` and
    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
    /// for more details)
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // We preserve ordering when this is the order-preserving variant or the input has at most one partition
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there is more than one input partition, the partitions will be fused at the output.
        // Therefore, remove per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
    }

    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so it should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no-op,
    /// and the node remains a `RepartitionExec`.
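    ///
    /// A minimal usage sketch (assumes `input` exposes an output ordering and
    /// has more than one partition, so the flag actually takes effect):
    ///
    /// ```ignore
    /// let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8))?
    ///     .with_preserve_order();
    /// ```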
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
                // If the input isn't ordered, there is no ordering to preserve
                self.input.output_ordering().is_some() &&
                // if there is only one input partition, merging is not required
                // to maintain order
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning
    ///
    /// `output_channels` holds the output sending channels for each output partition
    async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // execute the child operator
        let timer = metrics.fetch_time.timer();
        let mut stream = input.execute(partition, context)?;
        timer.done();

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                        reservation.lock().shrink(size);
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // If the input stream is endless, we may spin forever and
            // never yield back to tokio.  See
            // https://github.com/apache/datafusion/issues/5278.
            //
            // However, yielding on every batch causes a bottleneck
            // when running with multiple cores. See
            // https://github.com/apache/datafusion/issues/6290
            //
            // Thus, heuristically yield after producing num_partitions
            // batches
            //
            // In round robin this is ideal as each input will get a
            // new batch. In hash partitioning it may yield too often
            // on uneven distributions even if some partition cannot
            // make progress, but parallelism is going to be limited
            // in that case anyway
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for `input_task`, which is consuming one of the inputs, to
    /// complete. Upon successful completion, sends a `None` to each of
    /// the output tx channels to signal that this input is complete.
    /// Upon error, propagates the error to all output tx channels.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate error
        // note we ignore errors on send (.ok) as that means the receiver has already shut down.

        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                // send the same Arc'd error to all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap it because we need to send the error to all output partitions
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,

    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    input: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.recv().poll_unpin(cx) {
                Poll::Ready(Some(Some(v))) => {
                    if let Ok(batch) = &v {
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }

                    return Poll::Ready(Some(v));
                }
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;

                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // all input partitions have finished sending batches
                        return Poll::Ready(None);
                    } else {
                        // other partitions still have data to send
                        continue;
                    }
                }
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// This struct converts a receiver to a stream.
/// The receiver receives data on an SPSC channel.
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.recv().poll_unpin(cx) {
            Poll::Ready(Some(Some(v))) => {
                if let Ok(batch) = &v {
                    self.reservation
                        .lock()
                        .shrink(batch.get_array_memory_size());
                }
                Poll::Ready(Some(v))
            }
            Poll::Ready(Some(None)) => {
                // Input partition has finished sending batches
                Poll::Ready(None)
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::{
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
        {collect, expressions::col},
    };

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::cast::as_string_array;
    use datafusion_common::{arrow_datafusion_err, assert_batches_sorted_eq, exec_err};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;

    use tokio::task::JoinSet;

    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition];

        // repartition from 1 input to 4 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 1 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 5 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;

        let total_rows: usize = output_partitions
            .iter()
            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
            .sum();

        assert_eq!(8, output_partitions.len());
        assert_eq!(total_rows, 8 * 50 * 3);

        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute and collect results
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            // execute this *output* partition and collect all batches
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }

    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                // define input partitions
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                // repartition from 3 input to 5 output
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        // have to send at least one batch through to provoke the error
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        // This generates an error (partitioning type not supported)
        // but only after the plan is executed. The error should be
        // returned and no results produced
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn error_for_input_exec() {
        // This generates an error on a call to execute. The error
        // should be returned and no results produced.

        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Note: this should pass (the stream can be created) but the
        // error when the input is executed should get passed back
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // input stream returns one good batch and then one error. The
        // error should be returned.
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Note: this should pass (the stream can be created) but the
        // error when the input is executed should get passed back
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        // The mock exec doesn't return immediately (instead it
        // requires the input to wait at least once)
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        let expected = vec![
            "+------------------+",
            "| my_awesome_field |",
            "+------------------+",
            "| foo              |",
            "| bar              |",
            "| frob             |",
            "| baz              |",
            "+------------------+",
        ];

        assert_batches_sorted_eq!(&expected, &expected_batches);

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_batches_sorted_eq!(&expected, &batches);
    }

1344    #[tokio::test]
1345    async fn robin_repartition_with_dropping_output_stream() {
1346        let task_ctx = Arc::new(TaskContext::default());
1347        let partitioning = Partitioning::RoundRobinBatch(2);
1348        // The barrier exec waits to be pinged
1349        // requires the input to wait at least once)
1350        let input = Arc::new(make_barrier_exec());
1351
1352        // partition into two output streams
1353        let exec = RepartitionExec::try_new(
1354            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1355            partitioning,
1356        )
1357        .unwrap();
1358
1359        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1360        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1361
1362        // now, purposely drop output stream 0
1363        // *before* any outputs are produced
1364        drop(output_stream0);
1365
1366        // Now, start sending input
1367        let mut background_task = JoinSet::new();
1368        background_task.spawn(async move {
1369            input.wait().await;
1370        });
1371
1372        // output stream 1 should *not* error and have one of the input batches
1373        let batches = crate::common::collect(output_stream1).await.unwrap();
1374
1375        let expected = vec![
1376            "+------------------+",
1377            "| my_awesome_field |",
1378            "+------------------+",
1379            "| baz              |",
1380            "| frob             |",
1381            "| gaz              |",
1382            "| grob             |",
1383            "+------------------+",
1384        ];
1385
1386        assert_batches_sorted_eq!(&expected, &batches);
1387    }
1388
    #[tokio::test]
    // As the hash results might be different on different platforms or
    // with different compilers, we will compare the same execution with
    // and without dropping the output stream.
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // We first collect the results without dropping the output stream.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // sanity checks: no duplicates, and every value comes from the source set
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Now do the same, but drop output stream 0 before waiting for the barrier
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        assert_eq!(batches_without_drop, batches_with_drop);
    }

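    /// Flattens single-column string batches into a `Vec<&str>`, asserting
    /// along the way that each batch has exactly one column and no nulls.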
    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    /// Create a BarrierExec that returns two partitions of two batches each
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();

        // The barrier exec waits to be pinged
        // (requires the input to wait at least once)
        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

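    /// Dropping the `collect` future before it completes must cancel the
    /// spawned repartition tasks and release all references to the input.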
    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

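    /// A single input row hashes into exactly one of the two output
    /// partitions; the other partition should produce no batches at all
    /// rather than an empty batch.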
    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

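    /// With a 1-byte memory limit, every output partition should surface a
    /// `ResourcesExhausted` error instead of buffering repartitioned batches.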
    #[tokio::test]
    async fn oom() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // set up context with a tiny memory limit
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err =
                arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into());
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    /// Create `n` copies of the same test batch
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    /// Create a single one-column UInt32 batch with eight rows
    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
}

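/// Tests of the order-preserving (`with_preserve_order`) behavior of
/// [`RepartitionExec`]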
#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the plan is as expected
    ///
    /// `$EXPECTED_PLAN_LINES`: the expected plan, as a vec of strings
    /// `$PLAN`: the plan to format and compare
    ///
    macro_rules! assert_plan {
        ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => {
            let physical_plan = $PLAN;
            let formatted = crate::displayable(&physical_plan).indent(true).to_string();
            let actual: Vec<&str> = formatted.trim().lines().collect();

            let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES
                .iter().copied().collect();

            assert_eq!(
                expected_plan_lines, actual,
                "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n"
            );
        };
    }

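    /// With multiple sorted input partitions, `with_preserve_order` should
    /// take effect and be reflected in the displayed plan.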
    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions, and is sorted
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        // Repartition should preserve order
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
            "    DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

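    /// With only a single input partition there is nothing to merge, so
    /// `with_preserve_order` should have no effect on the displayed plan.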
    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))
            .unwrap()
            .with_preserve_order();

        // Repartition should not preserve order
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
            "  DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

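    /// With unsorted inputs there is no ordering to preserve, so
    /// `with_preserve_order` should have no effect on the displayed plan.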
    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::new(vec![source1, source2]);
        let exec =
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
                .unwrap()
                .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        let expected_plan = [
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
            "  UnionExec",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
            "    DataSourceExec: partitions=1, partition_sizes=[0]",
        ];
        assert_plan!(expected_plan, exec);
        Ok(())
    }

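    /// Single-column schema (`c0: UInt32`, non-nullable) used by these tests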
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

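    /// Lexicographic ordering on `c0` with default sort options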
    fn sort_exprs(schema: &Schema) -> LexOrdering {
        let options = SortOptions::default();
        LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options,
        }])
    }

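    /// In-memory source with one empty partition and no declared ordering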
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

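    /// In-memory source with one empty partition that declares the given
    /// sort ordering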
    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}
1722}