datafusion_physical_plan/sorts/
partial_sort.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Partial Sort deals with input data that partially
//! satisfies the required sort order. Such input data can be
//! partitioned into segments, where each segment already has the
//! required information for lexicographic sorting, so sorting
//! can be done without loading the entire dataset.
23//!
24//! Consider a sort plan having an input with ordering `a ASC, b ASC`
25//!
26//! ```text
27//! +---+---+---+
28//! | a | b | d |
29//! +---+---+---+
30//! | 0 | 0 | 3 |
31//! | 0 | 0 | 2 |
32//! | 0 | 1 | 1 |
33//! | 0 | 2 | 0 |
34//! +---+---+---+
35//!```
36//!
//! and the required ordering for the plan is `a ASC, b ASC, d ASC`.
//! The first 3 rows (the segments `a=0, b=0` and `a=0, b=1`) can be sorted, as these
//! segments already have the required information for the sort, but the last row
//! requires further information, as the input can continue with a
//! batch whose starting row has the same `a` and `b` values, as below
42//!
43//! ```text
44//! +---+---+---+
45//! | a | b | d |
46//! +---+---+---+
47//! | 0 | 2 | 4 |
48//! +---+---+---+
49//!```
50//!
//! The plan concatenates incoming data with such trailing rows of the
//! previous input and continues partially sorting the segments.
53
54use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::utils::evaluate_partition_ranges;
71use datafusion_common::Result;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{ready, Stream, StreamExt};
76use log::trace;
77
/// Partial Sort execution plan.
///
/// Sorts input that is expected to already be ordered on a leading prefix of
/// the required sort expressions (the first `common_prefix_length` entries of
/// `expr`). Rows sharing the same prefix values form a segment, and each
/// completed segment can be sorted without waiting for the rest of the input.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input (child) execution plan whose output is sorted
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Sort expressions: the complete required output ordering
    expr: LexOrdering,
    /// Length of continuous matching columns of input that satisfy
    /// the required ordering for the sort
    common_prefix_length: usize,
    /// Metrics set shared by the output streams; each executed partition
    /// registers its own `BaselineMetrics` here
    metrics_set: ExecutionPlanMetricsSet,
    /// Preserve partitions of input plan. If false, this node reports a single
    /// output partition and requires its input to be a single partition.
    preserve_partitioning: bool,
    /// Optional row limit: if `Some(n)`, each output stream emits at most `n`
    /// rows of the sorted output
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
98
99impl PartialSortExec {
100    /// Create a new partial sort execution plan
101    pub fn new(
102        expr: LexOrdering,
103        input: Arc<dyn ExecutionPlan>,
104        common_prefix_length: usize,
105    ) -> Self {
106        debug_assert!(common_prefix_length > 0);
107        let preserve_partitioning = false;
108        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning);
109        Self {
110            input,
111            expr,
112            common_prefix_length,
113            metrics_set: ExecutionPlanMetricsSet::new(),
114            preserve_partitioning,
115            fetch: None,
116            cache,
117        }
118    }
119
120    /// Whether this `PartialSortExec` preserves partitioning of the children
121    pub fn preserve_partitioning(&self) -> bool {
122        self.preserve_partitioning
123    }
124
125    /// Specify the partitioning behavior of this partial sort exec
126    ///
127    /// If `preserve_partitioning` is true, sorts each partition
128    /// individually, producing one sorted stream for each input partition.
129    ///
130    /// If `preserve_partitioning` is false, sorts and merges all
131    /// input partitions producing a single, sorted partition.
132    pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
133        self.preserve_partitioning = preserve_partitioning;
134        self.cache = self
135            .cache
136            .with_partitioning(Self::output_partitioning_helper(
137                &self.input,
138                self.preserve_partitioning,
139            ));
140        self
141    }
142
143    /// Modify how many rows to include in the result
144    ///
145    /// If None, then all rows will be returned, in sorted order.
146    /// If Some, then only the top `fetch` rows will be returned.
147    /// This can reduce the memory pressure required by the sort
148    /// operation since rows that are not going to be included
149    /// can be dropped.
150    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
151        self.fetch = fetch;
152        self
153    }
154
155    /// Input schema
156    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
157        &self.input
158    }
159
160    /// Sort expressions
161    pub fn expr(&self) -> &LexOrdering {
162        self.expr.as_ref()
163    }
164
165    /// If `Some(fetch)`, limits output to only the first "fetch" items
166    pub fn fetch(&self) -> Option<usize> {
167        self.fetch
168    }
169
170    /// Common prefix length
171    pub fn common_prefix_length(&self) -> usize {
172        self.common_prefix_length
173    }
174
175    fn output_partitioning_helper(
176        input: &Arc<dyn ExecutionPlan>,
177        preserve_partitioning: bool,
178    ) -> Partitioning {
179        // Get output partitioning:
180        if preserve_partitioning {
181            input.output_partitioning().clone()
182        } else {
183            Partitioning::UnknownPartitioning(1)
184        }
185    }
186
187    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
188    fn compute_properties(
189        input: &Arc<dyn ExecutionPlan>,
190        sort_exprs: LexOrdering,
191        preserve_partitioning: bool,
192    ) -> PlanProperties {
193        // Calculate equivalence properties; i.e. reset the ordering equivalence
194        // class with the new ordering:
195        let eq_properties = input
196            .equivalence_properties()
197            .clone()
198            .with_reorder(sort_exprs);
199
200        // Get output partitioning:
201        let output_partitioning =
202            Self::output_partitioning_helper(input, preserve_partitioning);
203
204        PlanProperties::new(
205            eq_properties,
206            output_partitioning,
207            input.pipeline_behavior(),
208            input.boundedness(),
209        )
210    }
211}
212
213impl DisplayAs for PartialSortExec {
214    fn fmt_as(
215        &self,
216        t: DisplayFormatType,
217        f: &mut std::fmt::Formatter,
218    ) -> std::fmt::Result {
219        match t {
220            DisplayFormatType::Default | DisplayFormatType::Verbose => {
221                let common_prefix_length = self.common_prefix_length;
222                match self.fetch {
223                    Some(fetch) => {
224                        write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr)
225                    }
226                    None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr),
227                }
228            }
229        }
230    }
231}
232
233impl ExecutionPlan for PartialSortExec {
234    fn name(&self) -> &'static str {
235        "PartialSortExec"
236    }
237
238    fn as_any(&self) -> &dyn Any {
239        self
240    }
241
242    fn properties(&self) -> &PlanProperties {
243        &self.cache
244    }
245
246    fn fetch(&self) -> Option<usize> {
247        self.fetch
248    }
249
250    fn required_input_distribution(&self) -> Vec<Distribution> {
251        if self.preserve_partitioning {
252            vec![Distribution::UnspecifiedDistribution]
253        } else {
254            vec![Distribution::SinglePartition]
255        }
256    }
257
258    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
259        vec![false]
260    }
261
262    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
263        vec![&self.input]
264    }
265
266    fn with_new_children(
267        self: Arc<Self>,
268        children: Vec<Arc<dyn ExecutionPlan>>,
269    ) -> Result<Arc<dyn ExecutionPlan>> {
270        let new_partial_sort = PartialSortExec::new(
271            self.expr.clone(),
272            Arc::clone(&children[0]),
273            self.common_prefix_length,
274        )
275        .with_fetch(self.fetch)
276        .with_preserve_partitioning(self.preserve_partitioning);
277
278        Ok(Arc::new(new_partial_sort))
279    }
280
281    fn execute(
282        &self,
283        partition: usize,
284        context: Arc<TaskContext>,
285    ) -> Result<SendableRecordBatchStream> {
286        trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
287
288        let input = self.input.execute(partition, Arc::clone(&context))?;
289
290        trace!(
291            "End PartialSortExec's input.execute for partition: {}",
292            partition
293        );
294
295        // Make sure common prefix length is larger than 0
296        // Otherwise, we should use SortExec.
297        debug_assert!(self.common_prefix_length > 0);
298
299        Ok(Box::pin(PartialSortStream {
300            input,
301            expr: self.expr.clone(),
302            common_prefix_length: self.common_prefix_length,
303            in_mem_batches: vec![],
304            fetch: self.fetch,
305            is_closed: false,
306            baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
307        }))
308    }
309
310    fn metrics(&self) -> Option<MetricsSet> {
311        Some(self.metrics_set.clone_inner())
312    }
313
314    fn statistics(&self) -> Result<Statistics> {
315        self.input.statistics()
316    }
317}
318
/// Stream produced by `PartialSortExec::execute`.
///
/// Buffers input rows until a segment boundary on the common sort prefix is
/// found, then sorts every complete segment buffered so far and emits it as a
/// single batch.
struct PartialSortStream {
    /// The input stream (output of the child plan for this partition)
    input: SendableRecordBatchStream,
    /// Sort expressions
    expr: LexOrdering,
    /// Length of prefix common to input ordering and required ordering of plan
    /// should be more than 0 otherwise PartialSort is not applicable
    common_prefix_length: usize,
    /// Buffered rows not yet ready to sort: their trailing segment may still
    /// continue in later input batches
    in_mem_batches: Vec<RecordBatch>,
    /// Remaining number of rows that may still be emitted; decremented after
    /// each sorted batch when set
    fetch: Option<usize>,
    /// Set once the input is exhausted or the fetch limit is reached; after
    /// that the stream only yields `None`
    is_closed: bool,
    /// Execution metrics
    baseline_metrics: BaselineMetrics,
}
336
337impl Stream for PartialSortStream {
338    type Item = Result<RecordBatch>;
339
340    fn poll_next(
341        mut self: Pin<&mut Self>,
342        cx: &mut Context<'_>,
343    ) -> Poll<Option<Self::Item>> {
344        let poll = self.poll_next_inner(cx);
345        self.baseline_metrics.record_poll(poll)
346    }
347
348    fn size_hint(&self) -> (usize, Option<usize>) {
349        // we can't predict the size of incoming batches so re-use the size hint from the input
350        self.input.size_hint()
351    }
352}
353
354impl RecordBatchStream for PartialSortStream {
355    fn schema(&self) -> SchemaRef {
356        self.input.schema()
357    }
358}
359
360impl PartialSortStream {
361    fn poll_next_inner(
362        self: &mut Pin<&mut Self>,
363        cx: &mut Context<'_>,
364    ) -> Poll<Option<Result<RecordBatch>>> {
365        if self.is_closed {
366            return Poll::Ready(None);
367        }
368        loop {
369            return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) {
370                Some(Ok(batch)) => {
371                    if let Some(slice_point) =
372                        self.get_slice_point(self.common_prefix_length, &batch)?
373                    {
374                        self.in_mem_batches.push(batch.slice(0, slice_point));
375                        let remaining_batch =
376                            batch.slice(slice_point, batch.num_rows() - slice_point);
377                        // Extract the sorted batch
378                        let sorted_batch = self.sort_in_mem_batches();
379                        // Refill with the remaining batch
380                        self.in_mem_batches.push(remaining_batch);
381
382                        debug_assert!(sorted_batch
383                            .as_ref()
384                            .map(|batch| batch.num_rows() > 0)
385                            .unwrap_or(true));
386                        Some(sorted_batch)
387                    } else {
388                        self.in_mem_batches.push(batch);
389                        continue;
390                    }
391                }
392                Some(Err(e)) => Some(Err(e)),
393                None => {
394                    self.is_closed = true;
395                    // once input is consumed, sort the rest of the inserted batches
396                    let remaining_batch = self.sort_in_mem_batches()?;
397                    if remaining_batch.num_rows() > 0 {
398                        Some(Ok(remaining_batch))
399                    } else {
400                        None
401                    }
402                }
403            });
404        }
405    }
406
407    /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
408    ///
409    /// If fetch is specified for PartialSortStream `sort_in_mem_batches` will limit
410    /// the last RecordBatch returned and will mark the stream as closed
411    fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
412        let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?;
413        self.in_mem_batches.clear();
414        let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?;
415        if let Some(remaining_fetch) = self.fetch {
416            // remaining_fetch - result.num_rows() is always be >= 0
417            // because result length of sort_batch with limit cannot be
418            // more than the requested limit
419            self.fetch = Some(remaining_fetch - result.num_rows());
420            if remaining_fetch == result.num_rows() {
421                self.is_closed = true;
422            }
423        }
424        Ok(result)
425    }
426
427    /// Return the end index of the second last partition if the batch
428    /// can be partitioned based on its already sorted columns
429    ///
430    /// Return None if the batch cannot be partitioned, which means the
431    /// batch does not have the information for a safe sort
432    fn get_slice_point(
433        &self,
434        common_prefix_len: usize,
435        batch: &RecordBatch,
436    ) -> Result<Option<usize>> {
437        let common_prefix_sort_keys = (0..common_prefix_len)
438            .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
439            .collect::<Result<Vec<_>>>()?;
440        let partition_points =
441            evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
442        // If partition points are [0..100], [100..200], [200..300]
443        // we should return 200, which is the safest and furthest partition boundary
444        // Please note that we shouldn't return 300 (which is number of rows in the batch),
445        // because this boundary may change with new data.
446        if partition_points.len() >= 2 {
447            Ok(Some(partition_points[partition_points.len() - 2].end))
448        } else {
449            Ok(None)
450        }
451    }
452}
453
454#[cfg(test)]
455mod tests {
456    use std::collections::HashMap;
457
458    use arrow::array::*;
459    use arrow::compute::SortOptions;
460    use arrow::datatypes::*;
461    use futures::FutureExt;
462    use itertools::Itertools;
463
464    use datafusion_common::assert_batches_eq;
465
466    use crate::collect;
467    use crate::expressions::col;
468    use crate::expressions::PhysicalSortExpr;
469    use crate::sorts::sort::SortExec;
470    use crate::test;
471    use crate::test::assert_is_pending;
472    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
473    use crate::test::TestMemoryExec;
474
475    use super::*;
476
477    #[tokio::test]
478    async fn test_partial_sort() -> Result<()> {
479        let task_ctx = Arc::new(TaskContext::default());
480        let source = test::build_table_scan_i32(
481            ("a", &vec![0, 0, 0, 1, 1, 1]),
482            ("b", &vec![1, 1, 2, 2, 3, 3]),
483            ("c", &vec![1, 0, 5, 4, 3, 2]),
484        );
485        let schema = Schema::new(vec![
486            Field::new("a", DataType::Int32, false),
487            Field::new("b", DataType::Int32, false),
488            Field::new("c", DataType::Int32, false),
489        ]);
490        let option_asc = SortOptions {
491            descending: false,
492            nulls_first: false,
493        };
494
495        let partial_sort_exec = Arc::new(PartialSortExec::new(
496            LexOrdering::new(vec![
497                PhysicalSortExpr {
498                    expr: col("a", &schema)?,
499                    options: option_asc,
500                },
501                PhysicalSortExpr {
502                    expr: col("b", &schema)?,
503                    options: option_asc,
504                },
505                PhysicalSortExpr {
506                    expr: col("c", &schema)?,
507                    options: option_asc,
508                },
509            ]),
510            Arc::clone(&source),
511            2,
512        )) as Arc<dyn ExecutionPlan>;
513
514        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
515
516        let expected_after_sort = [
517            "+---+---+---+",
518            "| a | b | c |",
519            "+---+---+---+",
520            "| 0 | 1 | 0 |",
521            "| 0 | 1 | 1 |",
522            "| 0 | 2 | 5 |",
523            "| 1 | 2 | 4 |",
524            "| 1 | 3 | 2 |",
525            "| 1 | 3 | 3 |",
526            "+---+---+---+",
527        ];
528        assert_eq!(2, result.len());
529        assert_batches_eq!(expected_after_sort, &result);
530        assert_eq!(
531            task_ctx.runtime_env().memory_pool.reserved(),
532            0,
533            "The sort should have returned all memory used back to the memory manager"
534        );
535
536        Ok(())
537    }
538
539    #[tokio::test]
540    async fn test_partial_sort_with_fetch() -> Result<()> {
541        let task_ctx = Arc::new(TaskContext::default());
542        let source = test::build_table_scan_i32(
543            ("a", &vec![0, 0, 1, 1, 1]),
544            ("b", &vec![1, 2, 2, 3, 3]),
545            ("c", &vec![4, 3, 2, 1, 0]),
546        );
547        let schema = Schema::new(vec![
548            Field::new("a", DataType::Int32, false),
549            Field::new("b", DataType::Int32, false),
550            Field::new("c", DataType::Int32, false),
551        ]);
552        let option_asc = SortOptions {
553            descending: false,
554            nulls_first: false,
555        };
556
557        for common_prefix_length in [1, 2] {
558            let partial_sort_exec = Arc::new(
559                PartialSortExec::new(
560                    LexOrdering::new(vec![
561                        PhysicalSortExpr {
562                            expr: col("a", &schema)?,
563                            options: option_asc,
564                        },
565                        PhysicalSortExpr {
566                            expr: col("b", &schema)?,
567                            options: option_asc,
568                        },
569                        PhysicalSortExpr {
570                            expr: col("c", &schema)?,
571                            options: option_asc,
572                        },
573                    ]),
574                    Arc::clone(&source),
575                    common_prefix_length,
576                )
577                .with_fetch(Some(4)),
578            ) as Arc<dyn ExecutionPlan>;
579
580            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
581
582            let expected_after_sort = [
583                "+---+---+---+",
584                "| a | b | c |",
585                "+---+---+---+",
586                "| 0 | 1 | 4 |",
587                "| 0 | 2 | 3 |",
588                "| 1 | 2 | 2 |",
589                "| 1 | 3 | 0 |",
590                "+---+---+---+",
591            ];
592            assert_eq!(2, result.len());
593            assert_batches_eq!(expected_after_sort, &result);
594            assert_eq!(
595                task_ctx.runtime_env().memory_pool.reserved(),
596                0,
597                "The sort should have returned all memory used back to the memory manager"
598            );
599        }
600
601        Ok(())
602    }
603
604    #[tokio::test]
605    async fn test_partial_sort2() -> Result<()> {
606        let task_ctx = Arc::new(TaskContext::default());
607        let source_tables = [
608            test::build_table_scan_i32(
609                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
610                ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
611                ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
612            ),
613            test::build_table_scan_i32(
614                ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
615                ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
616                ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
617            ),
618        ];
619        let schema = Schema::new(vec![
620            Field::new("a", DataType::Int32, false),
621            Field::new("b", DataType::Int32, false),
622            Field::new("c", DataType::Int32, false),
623        ]);
624        let option_asc = SortOptions {
625            descending: false,
626            nulls_first: false,
627        };
628        for (common_prefix_length, source) in
629            [(1, &source_tables[0]), (2, &source_tables[1])]
630        {
631            let partial_sort_exec = Arc::new(PartialSortExec::new(
632                LexOrdering::new(vec![
633                    PhysicalSortExpr {
634                        expr: col("a", &schema)?,
635                        options: option_asc,
636                    },
637                    PhysicalSortExpr {
638                        expr: col("b", &schema)?,
639                        options: option_asc,
640                    },
641                    PhysicalSortExpr {
642                        expr: col("c", &schema)?,
643                        options: option_asc,
644                    },
645                ]),
646                Arc::clone(source),
647                common_prefix_length,
648            ));
649
650            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
651            assert_eq!(2, result.len());
652            assert_eq!(
653                task_ctx.runtime_env().memory_pool.reserved(),
654                0,
655                "The sort should have returned all memory used back to the memory manager"
656            );
657            let expected = [
658                "+---+---+---+",
659                "| a | b | c |",
660                "+---+---+---+",
661                "| 0 | 1 | 6 |",
662                "| 0 | 1 | 7 |",
663                "| 0 | 3 | 4 |",
664                "| 0 | 3 | 5 |",
665                "| 1 | 2 | 0 |",
666                "| 1 | 2 | 1 |",
667                "| 1 | 4 | 2 |",
668                "| 1 | 4 | 3 |",
669                "+---+---+---+",
670            ];
671            assert_batches_eq!(expected, &result);
672        }
673        Ok(())
674    }
675
676    fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
677        let batch1 = test::build_table_i32(
678            ("a", &vec![1; 100]),
679            ("b", &(0..100).rev().collect()),
680            ("c", &(0..100).rev().collect()),
681        );
682        let batch2 = test::build_table_i32(
683            ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
684            ("b", &(100..200).rev().collect()),
685            ("c", &(0..100).collect()),
686        );
687        let batch3 = test::build_table_i32(
688            ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
689            ("b", &(150..250).rev().collect()),
690            ("c", &(0..100).rev().collect()),
691        );
692        let batch4 = test::build_table_i32(
693            ("a", &vec![4; 100]),
694            ("b", &(50..150).rev().collect()),
695            ("c", &(0..100).rev().collect()),
696        );
697        let schema = batch1.schema();
698
699        TestMemoryExec::try_new_exec(
700            &[vec![batch1, batch2, batch3, batch4]],
701            Arc::clone(&schema),
702            None,
703        )
704        .unwrap() as Arc<dyn ExecutionPlan>
705    }
706
707    #[tokio::test]
708    async fn test_partitioned_input_partial_sort() -> Result<()> {
709        let task_ctx = Arc::new(TaskContext::default());
710        let mem_exec = prepare_partitioned_input();
711        let option_asc = SortOptions {
712            descending: false,
713            nulls_first: false,
714        };
715        let option_desc = SortOptions {
716            descending: false,
717            nulls_first: false,
718        };
719        let schema = mem_exec.schema();
720        let partial_sort_executor = PartialSortExec::new(
721            LexOrdering::new(vec![
722                PhysicalSortExpr {
723                    expr: col("a", &schema)?,
724                    options: option_asc,
725                },
726                PhysicalSortExpr {
727                    expr: col("b", &schema)?,
728                    options: option_desc,
729                },
730                PhysicalSortExpr {
731                    expr: col("c", &schema)?,
732                    options: option_asc,
733                },
734            ]),
735            Arc::clone(&mem_exec),
736            1,
737        );
738        let partial_sort_exec =
739            Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
740        let sort_exec = Arc::new(SortExec::new(
741            partial_sort_executor.expr,
742            partial_sort_executor.input,
743        )) as Arc<dyn ExecutionPlan>;
744        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
745        assert_eq!(
746            result.iter().map(|r| r.num_rows()).collect_vec(),
747            [125, 125, 150]
748        );
749
750        assert_eq!(
751            task_ctx.runtime_env().memory_pool.reserved(),
752            0,
753            "The sort should have returned all memory used back to the memory manager"
754        );
755        let partial_sort_result = concat_batches(&schema, &result).unwrap();
756        let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
757        assert_eq!(sort_result[0], partial_sort_result);
758
759        Ok(())
760    }
761
762    #[tokio::test]
763    async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
764        let task_ctx = Arc::new(TaskContext::default());
765        let mem_exec = prepare_partitioned_input();
766        let schema = mem_exec.schema();
767        let option_asc = SortOptions {
768            descending: false,
769            nulls_first: false,
770        };
771        let option_desc = SortOptions {
772            descending: false,
773            nulls_first: false,
774        };
775        for (fetch_size, expected_batch_num_rows) in [
776            (Some(50), vec![50]),
777            (Some(120), vec![120]),
778            (Some(150), vec![125, 25]),
779            (Some(250), vec![125, 125]),
780        ] {
781            let partial_sort_executor = PartialSortExec::new(
782                LexOrdering::new(vec![
783                    PhysicalSortExpr {
784                        expr: col("a", &schema)?,
785                        options: option_asc,
786                    },
787                    PhysicalSortExpr {
788                        expr: col("b", &schema)?,
789                        options: option_desc,
790                    },
791                    PhysicalSortExpr {
792                        expr: col("c", &schema)?,
793                        options: option_asc,
794                    },
795                ]),
796                Arc::clone(&mem_exec),
797                1,
798            )
799            .with_fetch(fetch_size);
800
801            let partial_sort_exec =
802                Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
803            let sort_exec = Arc::new(
804                SortExec::new(partial_sort_executor.expr, partial_sort_executor.input)
805                    .with_fetch(fetch_size),
806            ) as Arc<dyn ExecutionPlan>;
807            let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
808            assert_eq!(
809                result.iter().map(|r| r.num_rows()).collect_vec(),
810                expected_batch_num_rows
811            );
812
813            assert_eq!(
814                task_ctx.runtime_env().memory_pool.reserved(),
815                0,
816                "The sort should have returned all memory used back to the memory manager"
817            );
818            let partial_sort_result = concat_batches(&schema, &result)?;
819            let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
820            assert_eq!(sort_result[0], partial_sort_result);
821        }
822
823        Ok(())
824    }
825
826    #[tokio::test]
827    async fn test_partial_sort_no_empty_batches() -> Result<()> {
828        let task_ctx = Arc::new(TaskContext::default());
829        let mem_exec = prepare_partitioned_input();
830        let schema = mem_exec.schema();
831        let option_asc = SortOptions {
832            descending: false,
833            nulls_first: false,
834        };
835        let fetch_size = Some(250);
836        let partial_sort_executor = PartialSortExec::new(
837            LexOrdering::new(vec![
838                PhysicalSortExpr {
839                    expr: col("a", &schema)?,
840                    options: option_asc,
841                },
842                PhysicalSortExpr {
843                    expr: col("c", &schema)?,
844                    options: option_asc,
845                },
846            ]),
847            Arc::clone(&mem_exec),
848            1,
849        )
850        .with_fetch(fetch_size);
851
852        let partial_sort_exec =
853            Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
854        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
855        for rb in result {
856            assert!(rb.num_rows() > 0);
857        }
858
859        Ok(())
860    }
861
862    #[tokio::test]
863    async fn test_sort_metadata() -> Result<()> {
864        let task_ctx = Arc::new(TaskContext::default());
865        let field_metadata: HashMap<String, String> =
866            vec![("foo".to_string(), "bar".to_string())]
867                .into_iter()
868                .collect();
869        let schema_metadata: HashMap<String, String> =
870            vec![("baz".to_string(), "barf".to_string())]
871                .into_iter()
872                .collect();
873
874        let mut field = Field::new("field_name", DataType::UInt64, true);
875        field.set_metadata(field_metadata.clone());
876        let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
877        let schema = Arc::new(schema);
878
879        let data: ArrayRef =
880            Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
881
882        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
883        let input =
884            TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
885
886        let partial_sort_exec = Arc::new(PartialSortExec::new(
887            LexOrdering::new(vec![PhysicalSortExpr {
888                expr: col("field_name", &schema)?,
889                options: SortOptions::default(),
890            }]),
891            input,
892            1,
893        ));
894
895        let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
896        let expected_batch = vec![
897            RecordBatch::try_new(
898                Arc::clone(&schema),
899                vec![Arc::new(
900                    vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
901                )],
902            )?,
903            RecordBatch::try_new(
904                Arc::clone(&schema),
905                vec![Arc::new(
906                    vec![2].into_iter().map(Some).collect::<UInt64Array>(),
907                )],
908            )?,
909        ];
910
911        // Data is correct
912        assert_eq!(&expected_batch, &result);
913
914        // explicitly ensure the metadata is present
915        assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
916        assert_eq!(result[0].schema().metadata(), &schema_metadata);
917
918        Ok(())
919    }
920
921    #[tokio::test]
922    async fn test_lex_sort_by_float() -> Result<()> {
923        let task_ctx = Arc::new(TaskContext::default());
924        let schema = Arc::new(Schema::new(vec![
925            Field::new("a", DataType::Float32, true),
926            Field::new("b", DataType::Float64, true),
927            Field::new("c", DataType::Float64, true),
928        ]));
929        let option_asc = SortOptions {
930            descending: false,
931            nulls_first: true,
932        };
933        let option_desc = SortOptions {
934            descending: true,
935            nulls_first: true,
936        };
937
938        // define data.
939        let batch = RecordBatch::try_new(
940            Arc::clone(&schema),
941            vec![
942                Arc::new(Float32Array::from(vec![
943                    Some(1.0_f32),
944                    Some(1.0_f32),
945                    Some(1.0_f32),
946                    Some(2.0_f32),
947                    Some(2.0_f32),
948                    Some(3.0_f32),
949                    Some(3.0_f32),
950                    Some(3.0_f32),
951                ])),
952                Arc::new(Float64Array::from(vec![
953                    Some(20.0_f64),
954                    Some(20.0_f64),
955                    Some(40.0_f64),
956                    Some(40.0_f64),
957                    Some(f64::NAN),
958                    None,
959                    None,
960                    Some(f64::NAN),
961                ])),
962                Arc::new(Float64Array::from(vec![
963                    Some(10.0_f64),
964                    Some(20.0_f64),
965                    Some(10.0_f64),
966                    Some(100.0_f64),
967                    Some(f64::NAN),
968                    Some(100.0_f64),
969                    None,
970                    Some(f64::NAN),
971                ])),
972            ],
973        )?;
974
975        let partial_sort_exec = Arc::new(PartialSortExec::new(
976            LexOrdering::new(vec![
977                PhysicalSortExpr {
978                    expr: col("a", &schema)?,
979                    options: option_asc,
980                },
981                PhysicalSortExpr {
982                    expr: col("b", &schema)?,
983                    options: option_asc,
984                },
985                PhysicalSortExpr {
986                    expr: col("c", &schema)?,
987                    options: option_desc,
988                },
989            ]),
990            TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
991            2,
992        ));
993
994        let expected = [
995            "+-----+------+-------+",
996            "| a   | b    | c     |",
997            "+-----+------+-------+",
998            "| 1.0 | 20.0 | 20.0  |",
999            "| 1.0 | 20.0 | 10.0  |",
1000            "| 1.0 | 40.0 | 10.0  |",
1001            "| 2.0 | 40.0 | 100.0 |",
1002            "| 2.0 | NaN  | NaN   |",
1003            "| 3.0 |      |       |",
1004            "| 3.0 |      | 100.0 |",
1005            "| 3.0 | NaN  | NaN   |",
1006            "+-----+------+-------+",
1007        ];
1008
1009        assert_eq!(
1010            DataType::Float32,
1011            *partial_sort_exec.schema().field(0).data_type()
1012        );
1013        assert_eq!(
1014            DataType::Float64,
1015            *partial_sort_exec.schema().field(1).data_type()
1016        );
1017        assert_eq!(
1018            DataType::Float64,
1019            *partial_sort_exec.schema().field(2).data_type()
1020        );
1021
1022        let result: Vec<RecordBatch> = collect(
1023            Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
1024            task_ctx,
1025        )
1026        .await?;
1027        assert_batches_eq!(expected, &result);
1028        assert_eq!(result.len(), 2);
1029        let metrics = partial_sort_exec.metrics().unwrap();
1030        assert!(metrics.elapsed_compute().unwrap() > 0);
1031        assert_eq!(metrics.output_rows().unwrap(), 8);
1032
1033        let columns = result[0].columns();
1034
1035        assert_eq!(DataType::Float32, *columns[0].data_type());
1036        assert_eq!(DataType::Float64, *columns[1].data_type());
1037        assert_eq!(DataType::Float64, *columns[2].data_type());
1038
1039        Ok(())
1040    }
1041
1042    #[tokio::test]
1043    async fn test_drop_cancel() -> Result<()> {
1044        let task_ctx = Arc::new(TaskContext::default());
1045        let schema = Arc::new(Schema::new(vec![
1046            Field::new("a", DataType::Float32, true),
1047            Field::new("b", DataType::Float32, true),
1048        ]));
1049
1050        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1051        let refs = blocking_exec.refs();
1052        let sort_exec = Arc::new(PartialSortExec::new(
1053            LexOrdering::new(vec![PhysicalSortExpr {
1054                expr: col("a", &schema)?,
1055                options: SortOptions::default(),
1056            }]),
1057            blocking_exec,
1058            1,
1059        ));
1060
1061        let fut = collect(sort_exec, Arc::clone(&task_ctx));
1062        let mut fut = fut.boxed();
1063
1064        assert_is_pending(&mut fut);
1065        drop(fut);
1066        assert_strong_count_converges_to_zero(refs).await;
1067
1068        assert_eq!(
1069            task_ctx.runtime_env().memory_pool.reserved(),
1070            0,
1071            "The sort should have returned all memory used back to the memory manager"
1072        );
1073
1074        Ok(())
1075    }
1076}