datafusion_physical_plan/sorts/sort_preserving_merge.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream.

use std::any::Any;
use std::sync::Arc;

use crate::common::spawn_buffered;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{make_with_child, update_expr, ProjectionExec};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};

use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};

use log::{debug, trace};

/// Sort preserving merge execution plan
///
/// # Overview
///
/// This operator implements a K-way merge. It is used to merge multiple sorted
/// streams into a single sorted stream and is highly optimized.
///
/// ## Inputs:
///
/// 1. A list of sort expressions
/// 2. An input plan, where each partition is sorted with respect to
///    these sort expressions.
///
/// ## Output:
///
/// 1. A single partition that is also sorted with respect to the expressions
///
/// ## Diagram
///
/// ```text
/// ┌─────────────────────────┐
/// │ ┌───┬───┬───┬───┐       │
/// │ │ A │ B │ C │ D │ ...   │──┐
/// │ └───┴───┴───┴───┘       │  │
/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
/// │ ╔═══╦═══╗               │  │                             │
/// │ ║ B ║ E ║     ...       │──┘                             │
/// │ ╚═══╩═══╝               │              Stable sort if `enable_round_robin_repartition=false`:
/// └─────────────────────────┘              the merged stream places equal rows from stream 1
///   Stream 2
///
///
///  Input Partitions                                          Output Partition
///    (sorted)                                                  (sorted)
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// the output and inputs are not polled again.
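///
/// # Example
///
/// A minimal sketch of constructing the operator (illustrative only:
/// `ordering` is an assumed `LexOrdering` and `sorted_input` an assumed
/// `Arc<dyn ExecutionPlan>` whose partitions are each sorted by `ordering`;
/// neither is defined here):
///
/// ```ignore
/// let merge = SortPreservingMergeExec::new(ordering, sorted_input)
///     // optionally stop after 10 rows
///     .with_fetch(Some(10));
/// ```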
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan with sorted partitions
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: LexOrdering,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing rows after this fetch
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
    /// Use round-robin selection of tied winners of the loser tree
    ///
    /// See [`Self::with_round_robin_repartition`] for more information.
    enable_round_robin_repartition: bool,
}

impl SortPreservingMergeExec {
    /// Create a new sort execution plan
    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input, expr.clone());
        Self {
            input,
            expr,
            metrics: ExecutionPlanMetricsSet::new(),
            fetch: None,
            cache,
            enable_round_robin_repartition: true,
        }
    }

    /// Sets the number of rows to fetch
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Sets the selection strategy of tied winners of the loser tree algorithm
    ///
    /// If true (the default), equal output rows are placed in the merged stream
    /// in round robin fashion. This approach consumes input streams at more
    /// even rates when there are many rows with the same sort key.
    ///
    /// If false, equal output rows are always placed in the merged stream in
    /// the order of the inputs, resulting in potentially slower execution but a
    /// stable output order.
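    ///
    /// A sketch of opting into a stable merge (illustrative; `merge` is a
    /// hypothetical `SortPreservingMergeExec` built with [`Self::new`]):
    ///
    /// ```ignore
    /// // equal rows now appear in input order, possibly at some cost in speed
    /// let merge = merge.with_round_robin_repartition(false);
    /// ```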
    pub fn with_round_robin_repartition(
        mut self,
        enable_round_robin_repartition: bool,
    ) -> Self {
        self.enable_round_robin_repartition = enable_round_robin_repartition;
        self
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Sort expressions
    pub fn expr(&self) -> &LexOrdering {
        self.expr.as_ref()
    }

    /// Fetch
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        ordering: LexOrdering,
    ) -> PlanProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_per_partition_constants();
        eq_properties.add_new_orderings(vec![ordering]);
        PlanProperties::new(
            eq_properties,                        // Equivalence Properties
            Partitioning::UnknownPartitioning(1), // Output Partitioning
            input.pipeline_behavior(),            // Pipeline Behavior
            input.boundedness(),                  // Boundedness
        )
    }
}

impl DisplayAs for SortPreservingMergeExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
                if let Some(fetch) = self.fetch {
                    write!(f, ", fetch={fetch}")?;
                };

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for SortPreservingMergeExec {
    fn name(&self) -> &'static str {
        "SortPreservingMergeExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Sets the number of rows to fetch
    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            expr: self.expr.clone(),
            metrics: self.metrics.clone(),
            fetch: limit,
            cache: self.cache.clone(),
            enable_round_robin_repartition: true,
        }))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::UnspecifiedDistribution]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![Some(LexRequirement::from(self.expr.clone()))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
                .with_fetch(self.fetch),
        ))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start SortPreservingMergeExec::execute for partition: {}",
            partition
        );
        if 0 != partition {
            return internal_err!(
                "SortPreservingMergeExec invalid partition {partition}"
            );
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        trace!(
            "Number of input partitions of SortPreservingMergeExec::execute: {}",
            input_partitions
        );
        let schema = self.schema();

        let reservation =
            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
                .register(&context.runtime_env().memory_pool);

        match input_partitions {
            0 => internal_err!(
                "SortPreservingMergeExec requires at least one input partition"
            ),
            1 => match self.fetch {
                Some(fetch) => {
                    let stream = self.input.execute(0, context)?;
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}");
                    Ok(Box::pin(LimitStream::new(
                        stream,
                        0,
                        Some(fetch),
                        BaselineMetrics::new(&self.metrics, partition),
                    )))
                }
                None => {
                    let stream = self.input.execute(0, context);
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch");
                    stream
                }
            },
            _ => {
                let receivers = (0..input_partitions)
                    .map(|partition| {
                        let stream =
                            self.input.execute(partition, Arc::clone(&context))?;
                        Ok(spawn_buffered(stream, 1))
                    })
                    .collect::<Result<_>>()?;

                debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

                let result = StreamingMergeBuilder::new()
                    .with_streams(receivers)
                    .with_schema(schema)
                    .with_expressions(self.expr.as_ref())
                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(self.fetch)
                    .with_reservation(reservation)
                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
                    .build()?;

                debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

                Ok(result)
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    /// Tries to swap the projection with its input [`SortPreservingMergeExec`].
    /// If this is possible, it returns the new [`SortPreservingMergeExec`] whose
    /// child is a projection. Otherwise, it returns None.
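    ///
    /// A sketch of the rewrite (illustrative; `a` stands for a hypothetical
    /// projected column and "wider input" for a child with more columns):
    ///
    /// ```text
    /// ProjectionExec(a)                        SortPreservingMergeExec [a ASC]
    ///   SortPreservingMergeExec [a ASC]   =>     ProjectionExec(a)
    ///     wider input                              wider input
    /// ```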
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        let mut updated_exprs = LexOrdering::default();
        for sort in self.expr() {
            let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)?
            else {
                return Ok(None);
            };
            updated_exprs.push(PhysicalSortExpr {
                expr: updated_expr,
                options: sort.options,
            });
        }

        Ok(Some(Arc::new(
            SortPreservingMergeExec::new(
                updated_exprs,
                make_with_child(projection, self.input())?,
            )
            .with_fetch(self.fetch()),
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Formatter;
    use std::pin::Pin;
    use std::sync::Mutex;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::execution_plan::{Boundedness, EmissionType};
    use crate::expressions::col;
    use crate::metrics::{MetricValue, Timestamp};
    use crate::repartition::RepartitionExec;
    use crate::sorts::sort::SortExec;
    use crate::stream::RecordBatchReceiverStream;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::test::{self, assert_is_pending, make_partition};
    use crate::{collect, common};

    use arrow::array::{
        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
        TimestampNanosecondArray,
    };
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_execution::RecordBatchStream;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::EquivalenceProperties;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

    use futures::{FutureExt, Stream, StreamExt};
    use tokio::time::timeout;

    // The constants in this function are closely tied to the memory limit being
    // tested; any change to them must account for that limit.
    fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(20_000_000, 1.0)
            .build_arc()?;
        let config = SessionConfig::new();
        let task_ctx = TaskContext::default()
            .with_runtime(runtime)
            .with_session_config(config);
        Ok(Arc::new(task_ctx))
    }

    // The constants in this function are closely tied to the memory limit being
    // tested; any change to them must account for that limit.
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let target_batch_size = 12500;
        let row_size = 12500;
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();

        let schema = rb.schema();
        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]);

        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None).unwrap(),
            Partitioning::RoundRobinBatch(2),
        )?;
        let coalesce_batches_exec =
            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }

    /// This test verifies that memory usage stays within limits when the tie breaker is enabled.
    /// Any errors here could indicate unintended changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await.unwrap();
        Ok(())
    }

    /// This test verifies that memory usage exceeds the limit when the tie breaker
    /// is disabled, so collection fails. An unexpected success here could indicate
    /// unintended changes in tie breaker logic.
    ///
    /// Note: If you adjust constants in this test, ensure that memory usage differs
    /// based on whether the tie breaker is enabled or disabled.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }

    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_exprs() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap();

        let schema = batch.schema();
        let sort = LexOrdering::default(); // no sort expressions
        let exec = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()], vec![batch]],
            schema,
            None,
        )
        .unwrap();

        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let res = collect(merge, task_ctx).await.unwrap_err();
        assert_contains!(
            res.to_string(),
            "Internal error: Sort expressions cannot be empty for streaming merge"
        );
    }

    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]);
        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }

    async fn sorted_merge(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let mut result = collect(merge, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    async fn partition_sort(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let sort_exec =
            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
        sorted_merge(sort_exec, sort, context).await
    }

    async fn basic_sort(
        src: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(CoalescePartitionsExec::new(src));
        let sort_exec = Arc::new(SortExec::new(sort, merge));
        let mut result = collect(sort_exec, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]);

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    // Split the provided record batch into multiple record batches, each with
    // at most batch_size rows
    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        let batches = sorted.num_rows().div_ceil(batch_size);

        // Split the sorted RecordBatch into multiple batches
        (0..batches)
            .map(|batch_idx| {
                let columns = (0..sorted.num_columns())
                    .map(|column_idx| {
                        let length =
                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);

                        sorted
                            .column(column_idx)
                            .slice(batch_idx * batch_size, length)
                    })
                    .collect();

                RecordBatch::try_new(sorted.schema(), columns).unwrap()
            })
            .collect()
    }

    async fn sorted_partitioned_input(
        sort: LexOrdering,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        Ok(TestMemoryExec::try_new_exec(&split, sorted.schema(), None).unwrap())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }]);

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }]);

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await.unwrap();

        assert_eq!(merged.len(), 53);

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice())
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]);
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+-------------------------------+",
                "| a | b | c                             |",
                "+---+---+-------------------------------+",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 2 | a |                               |",
                "| 7 | b | 1970-01-01T00:00:00.000000006 |",
                "| 2 | b |                               |",
                "| 9 | d |                               |",
                "| 3 | e | 1970-01-01T00:00:00.000000004 |",
                "| 3 | g | 1970-01-01T00:00:00.000000005 |",
                "| 4 | h |                               |",
                "| 5 | i | 1970-01-01T00:00:00.000000004 |",
                "+---+---+-------------------------------+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "| 7 | c |",
                "| 9 | d |",
                "| 3 | e |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]);

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // This causes the MergeStream to wait for more input
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(sort.as_ref())
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]);
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        let expected = [
            "+----+---+",
            "| a  | b |",
            "+----+---+",
            "| 1  | a |",
            "| 10 | b |",
            "| 2  | c |",
            "| 20 | d |",
            "+----+---+",
        ];
        assert_batches_eq!(expected, collected.as_slice());

        // Now, validate metrics
        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }

    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]),
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

        // Create record batches like:
        // batch_number |value
        // -------------+------
        //    1         | A
        //    1         | B
        //
        // Ensure that the output is in the same order the batches were fed
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Expect the data to be sorted first by "batch_number" (because
        // that was the order it was fed in, even though only "value"
        // is in the sort key)
        assert_batches_eq!(
            &[
                "+--------------+-------+",
                "| batch_number | value |",
                "+--------------+-------+",
                "| 0            | A     |",
                "| 1            | A     |",
                "| 2            | A     |",
                "| 3            | A     |",
                "| 4            | A     |",
                "| 5            | A     |",
                "| 6            | A     |",
                "| 7            | A     |",
                "| 8            | A     |",
                "| 9            | A     |",
                "| 0            | B     |",
                "| 1            | B     |",
                "| 2            | B     |",
                "| 3            | B     |",
                "| 4            | B     |",
                "| 5            | B     |",
                "| 6            | B     |",
                "| 7            | B     |",
                "| 8            | B     |",
                "| 9            | B     |",
                "+--------------+-------+",
            ],
            collected.as_slice()
        );
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
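    ///
    /// A sketch of the poll sequence the test expects the merge to drive
    /// (illustrative):
    ///
    /// ```text
    /// poll partition 0 -> Ready(None)  // exhausted; panics if polled again
    /// poll partition 1 -> Pending      // congested until partition 2 is polled
    /// poll partition 2 -> Ready(None)  // clears the congestion flag
    /// poll partition 1 -> Ready(None)  // now completes
    /// ```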
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion_cleared: Arc<Mutex<bool>>,
    }

    impl CongestedExec {
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_new_orderings(vec![columns
                .iter()
                .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr)))
                .collect::<LexOrdering>()]);
            PlanProperties::new(
                eq_properties,
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }

    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion_cleared: Arc::clone(&self.congestion_cleared),
                partition,
            }))
        }
    }

    impl DisplayAs for CongestedExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CongestedExec",).unwrap()
                }
            }
            Ok(())
        }
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        none_polled_once: bool,
        congestion_cleared: Arc<Mutex<bool>>,
        partition: usize,
    }

    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                1 => {
                    let cleared = self.congestion_cleared.lock().unwrap();
                    if *cleared {
                        Poll::Ready(None)
                    } else {
                        Poll::Pending
                    }
                }
                2 => {
                    let mut cleared = self.congestion_cleared.lock().unwrap();
                    *cleared = true;
                    Poll::Ready(None)
                }
                _ => unreachable!(),
            }
        }
    }

    impl RecordBatchStream for CongestedStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }

    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let source = CongestedExec {
            schema: schema.clone(),
            cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
            congestion_cleared: Arc::new(Mutex::new(false)),
        };
        let spm = SortPreservingMergeExec::new(
            LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]),
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => Err(DataFusionError::Execution(
                "SortPreservingMerge task panicked or was cancelled".to_string(),
            )),
            Err(_) => Err(DataFusionError::Execution(
                "SortPreservingMerge caused a deadlock".to_string(),
            )),
        }
    }
}