1use std::any::Any;
55use std::fmt::Debug;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::task::{Context, Poll};
59
60use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
61use crate::sorts::sort::sort_batch;
62use crate::{
63 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
64 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
65};
66
67use arrow::compute::concat_batches;
68use arrow::datatypes::SchemaRef;
69use arrow::record_batch::RecordBatch;
70use datafusion_common::utils::evaluate_partition_ranges;
71use datafusion_common::Result;
72use datafusion_execution::{RecordBatchStream, TaskContext};
73use datafusion_physical_expr::LexOrdering;
74
75use futures::{ready, Stream, StreamExt};
76use log::trace;
77
/// Execution plan node that sorts its input by `expr`, using the first
/// `common_prefix_length` sort expressions to find points in the input at
/// which already-buffered rows can be sorted and emitted.
///
/// NOTE(review): the input is presumably required to be ordered on that
/// common prefix already; nothing in this type verifies it.
#[derive(Debug, Clone)]
pub struct PartialSortExec {
    /// Input plan whose record batches are sorted.
    pub(crate) input: Arc<dyn ExecutionPlan>,
    /// Full lexicographic ordering to produce.
    expr: LexOrdering,
    /// Number of leading entries of `expr` used to detect group boundaries
    /// inside incoming batches. Expected to be greater than zero.
    common_prefix_length: usize,
    /// Metrics recorded while this node executes.
    metrics_set: ExecutionPlanMetricsSet,
    /// When `true`, the input partitioning is reported as the output
    /// partitioning; when `false`, a single output partition is reported and
    /// a single input partition is required.
    preserve_partitioning: bool,
    /// Optional limit on the number of rows to output.
    fetch: Option<usize>,
    /// Plan properties computed when the node is built; returned by
    /// `properties()`.
    cache: PlanProperties,
}
98
99impl PartialSortExec {
100 pub fn new(
102 expr: LexOrdering,
103 input: Arc<dyn ExecutionPlan>,
104 common_prefix_length: usize,
105 ) -> Self {
106 debug_assert!(common_prefix_length > 0);
107 let preserve_partitioning = false;
108 let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning);
109 Self {
110 input,
111 expr,
112 common_prefix_length,
113 metrics_set: ExecutionPlanMetricsSet::new(),
114 preserve_partitioning,
115 fetch: None,
116 cache,
117 }
118 }
119
120 pub fn preserve_partitioning(&self) -> bool {
122 self.preserve_partitioning
123 }
124
125 pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self {
133 self.preserve_partitioning = preserve_partitioning;
134 self.cache = self
135 .cache
136 .with_partitioning(Self::output_partitioning_helper(
137 &self.input,
138 self.preserve_partitioning,
139 ));
140 self
141 }
142
143 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
151 self.fetch = fetch;
152 self
153 }
154
155 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
157 &self.input
158 }
159
160 pub fn expr(&self) -> &LexOrdering {
162 self.expr.as_ref()
163 }
164
165 pub fn fetch(&self) -> Option<usize> {
167 self.fetch
168 }
169
170 pub fn common_prefix_length(&self) -> usize {
172 self.common_prefix_length
173 }
174
175 fn output_partitioning_helper(
176 input: &Arc<dyn ExecutionPlan>,
177 preserve_partitioning: bool,
178 ) -> Partitioning {
179 if preserve_partitioning {
181 input.output_partitioning().clone()
182 } else {
183 Partitioning::UnknownPartitioning(1)
184 }
185 }
186
187 fn compute_properties(
189 input: &Arc<dyn ExecutionPlan>,
190 sort_exprs: LexOrdering,
191 preserve_partitioning: bool,
192 ) -> PlanProperties {
193 let eq_properties = input
196 .equivalence_properties()
197 .clone()
198 .with_reorder(sort_exprs);
199
200 let output_partitioning =
202 Self::output_partitioning_helper(input, preserve_partitioning);
203
204 PlanProperties::new(
205 eq_properties,
206 output_partitioning,
207 input.pipeline_behavior(),
208 input.boundedness(),
209 )
210 }
211}
212
213impl DisplayAs for PartialSortExec {
214 fn fmt_as(
215 &self,
216 t: DisplayFormatType,
217 f: &mut std::fmt::Formatter,
218 ) -> std::fmt::Result {
219 match t {
220 DisplayFormatType::Default | DisplayFormatType::Verbose => {
221 let common_prefix_length = self.common_prefix_length;
222 match self.fetch {
223 Some(fetch) => {
224 write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr)
225 }
226 None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr),
227 }
228 }
229 }
230 }
231}
232
233impl ExecutionPlan for PartialSortExec {
234 fn name(&self) -> &'static str {
235 "PartialSortExec"
236 }
237
238 fn as_any(&self) -> &dyn Any {
239 self
240 }
241
242 fn properties(&self) -> &PlanProperties {
243 &self.cache
244 }
245
246 fn fetch(&self) -> Option<usize> {
247 self.fetch
248 }
249
250 fn required_input_distribution(&self) -> Vec<Distribution> {
251 if self.preserve_partitioning {
252 vec![Distribution::UnspecifiedDistribution]
253 } else {
254 vec![Distribution::SinglePartition]
255 }
256 }
257
258 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
259 vec![false]
260 }
261
262 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
263 vec![&self.input]
264 }
265
266 fn with_new_children(
267 self: Arc<Self>,
268 children: Vec<Arc<dyn ExecutionPlan>>,
269 ) -> Result<Arc<dyn ExecutionPlan>> {
270 let new_partial_sort = PartialSortExec::new(
271 self.expr.clone(),
272 Arc::clone(&children[0]),
273 self.common_prefix_length,
274 )
275 .with_fetch(self.fetch)
276 .with_preserve_partitioning(self.preserve_partitioning);
277
278 Ok(Arc::new(new_partial_sort))
279 }
280
281 fn execute(
282 &self,
283 partition: usize,
284 context: Arc<TaskContext>,
285 ) -> Result<SendableRecordBatchStream> {
286 trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
287
288 let input = self.input.execute(partition, Arc::clone(&context))?;
289
290 trace!(
291 "End PartialSortExec's input.execute for partition: {}",
292 partition
293 );
294
295 debug_assert!(self.common_prefix_length > 0);
298
299 Ok(Box::pin(PartialSortStream {
300 input,
301 expr: self.expr.clone(),
302 common_prefix_length: self.common_prefix_length,
303 in_mem_batches: vec![],
304 fetch: self.fetch,
305 is_closed: false,
306 baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition),
307 }))
308 }
309
310 fn metrics(&self) -> Option<MetricsSet> {
311 Some(self.metrics_set.clone_inner())
312 }
313
314 fn statistics(&self) -> Result<Statistics> {
315 self.input.statistics()
316 }
317}
318
/// Stream that buffers input batches and emits sorted batches whenever a
/// boundary of the common sort prefix is found inside an incoming batch.
struct PartialSortStream {
    /// Stream of input batches.
    input: SendableRecordBatchStream,
    /// Full ordering applied to each group of buffered rows.
    expr: LexOrdering,
    /// Number of leading entries of `expr` used to detect group boundaries
    /// in an incoming batch.
    common_prefix_length: usize,
    /// Batches (or batch slices) buffered until the next sort.
    in_mem_batches: Vec<RecordBatch>,
    /// Remaining number of rows that may still be emitted, if limited.
    fetch: Option<usize>,
    /// Set once the stream must not produce any further batches.
    is_closed: bool,
    /// Metrics recorded for every poll of this stream.
    baseline_metrics: BaselineMetrics,
}
336
337impl Stream for PartialSortStream {
338 type Item = Result<RecordBatch>;
339
340 fn poll_next(
341 mut self: Pin<&mut Self>,
342 cx: &mut Context<'_>,
343 ) -> Poll<Option<Self::Item>> {
344 let poll = self.poll_next_inner(cx);
345 self.baseline_metrics.record_poll(poll)
346 }
347
348 fn size_hint(&self) -> (usize, Option<usize>) {
349 self.input.size_hint()
351 }
352}
353
impl RecordBatchStream for PartialSortStream {
    /// The output schema is the schema reported by the input stream.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }
}
359
360impl PartialSortStream {
361 fn poll_next_inner(
362 self: &mut Pin<&mut Self>,
363 cx: &mut Context<'_>,
364 ) -> Poll<Option<Result<RecordBatch>>> {
365 if self.is_closed {
366 return Poll::Ready(None);
367 }
368 loop {
369 return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) {
370 Some(Ok(batch)) => {
371 if let Some(slice_point) =
372 self.get_slice_point(self.common_prefix_length, &batch)?
373 {
374 self.in_mem_batches.push(batch.slice(0, slice_point));
375 let remaining_batch =
376 batch.slice(slice_point, batch.num_rows() - slice_point);
377 let sorted_batch = self.sort_in_mem_batches();
379 self.in_mem_batches.push(remaining_batch);
381
382 debug_assert!(sorted_batch
383 .as_ref()
384 .map(|batch| batch.num_rows() > 0)
385 .unwrap_or(true));
386 Some(sorted_batch)
387 } else {
388 self.in_mem_batches.push(batch);
389 continue;
390 }
391 }
392 Some(Err(e)) => Some(Err(e)),
393 None => {
394 self.is_closed = true;
395 let remaining_batch = self.sort_in_mem_batches()?;
397 if remaining_batch.num_rows() > 0 {
398 Some(Ok(remaining_batch))
399 } else {
400 None
401 }
402 }
403 });
404 }
405 }
406
407 fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result<RecordBatch> {
412 let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?;
413 self.in_mem_batches.clear();
414 let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?;
415 if let Some(remaining_fetch) = self.fetch {
416 self.fetch = Some(remaining_fetch - result.num_rows());
420 if remaining_fetch == result.num_rows() {
421 self.is_closed = true;
422 }
423 }
424 Ok(result)
425 }
426
427 fn get_slice_point(
433 &self,
434 common_prefix_len: usize,
435 batch: &RecordBatch,
436 ) -> Result<Option<usize>> {
437 let common_prefix_sort_keys = (0..common_prefix_len)
438 .map(|idx| self.expr[idx].evaluate_to_sort_column(batch))
439 .collect::<Result<Vec<_>>>()?;
440 let partition_points =
441 evaluate_partition_ranges(batch.num_rows(), &common_prefix_sort_keys)?;
442 if partition_points.len() >= 2 {
447 Ok(Some(partition_points[partition_points.len() - 2].end))
448 } else {
449 Ok(None)
450 }
451 }
452}
453
454#[cfg(test)]
455mod tests {
456 use std::collections::HashMap;
457
458 use arrow::array::*;
459 use arrow::compute::SortOptions;
460 use arrow::datatypes::*;
461 use futures::FutureExt;
462 use itertools::Itertools;
463
464 use datafusion_common::assert_batches_eq;
465
466 use crate::collect;
467 use crate::expressions::col;
468 use crate::expressions::PhysicalSortExpr;
469 use crate::sorts::sort::SortExec;
470 use crate::test;
471 use crate::test::assert_is_pending;
472 use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
473 use crate::test::TestMemoryExec;
474
475 use super::*;
476
477 #[tokio::test]
478 async fn test_partial_sort() -> Result<()> {
479 let task_ctx = Arc::new(TaskContext::default());
480 let source = test::build_table_scan_i32(
481 ("a", &vec![0, 0, 0, 1, 1, 1]),
482 ("b", &vec![1, 1, 2, 2, 3, 3]),
483 ("c", &vec![1, 0, 5, 4, 3, 2]),
484 );
485 let schema = Schema::new(vec![
486 Field::new("a", DataType::Int32, false),
487 Field::new("b", DataType::Int32, false),
488 Field::new("c", DataType::Int32, false),
489 ]);
490 let option_asc = SortOptions {
491 descending: false,
492 nulls_first: false,
493 };
494
495 let partial_sort_exec = Arc::new(PartialSortExec::new(
496 LexOrdering::new(vec![
497 PhysicalSortExpr {
498 expr: col("a", &schema)?,
499 options: option_asc,
500 },
501 PhysicalSortExpr {
502 expr: col("b", &schema)?,
503 options: option_asc,
504 },
505 PhysicalSortExpr {
506 expr: col("c", &schema)?,
507 options: option_asc,
508 },
509 ]),
510 Arc::clone(&source),
511 2,
512 )) as Arc<dyn ExecutionPlan>;
513
514 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
515
516 let expected_after_sort = [
517 "+---+---+---+",
518 "| a | b | c |",
519 "+---+---+---+",
520 "| 0 | 1 | 0 |",
521 "| 0 | 1 | 1 |",
522 "| 0 | 2 | 5 |",
523 "| 1 | 2 | 4 |",
524 "| 1 | 3 | 2 |",
525 "| 1 | 3 | 3 |",
526 "+---+---+---+",
527 ];
528 assert_eq!(2, result.len());
529 assert_batches_eq!(expected_after_sort, &result);
530 assert_eq!(
531 task_ctx.runtime_env().memory_pool.reserved(),
532 0,
533 "The sort should have returned all memory used back to the memory manager"
534 );
535
536 Ok(())
537 }
538
539 #[tokio::test]
540 async fn test_partial_sort_with_fetch() -> Result<()> {
541 let task_ctx = Arc::new(TaskContext::default());
542 let source = test::build_table_scan_i32(
543 ("a", &vec![0, 0, 1, 1, 1]),
544 ("b", &vec![1, 2, 2, 3, 3]),
545 ("c", &vec![4, 3, 2, 1, 0]),
546 );
547 let schema = Schema::new(vec![
548 Field::new("a", DataType::Int32, false),
549 Field::new("b", DataType::Int32, false),
550 Field::new("c", DataType::Int32, false),
551 ]);
552 let option_asc = SortOptions {
553 descending: false,
554 nulls_first: false,
555 };
556
557 for common_prefix_length in [1, 2] {
558 let partial_sort_exec = Arc::new(
559 PartialSortExec::new(
560 LexOrdering::new(vec![
561 PhysicalSortExpr {
562 expr: col("a", &schema)?,
563 options: option_asc,
564 },
565 PhysicalSortExpr {
566 expr: col("b", &schema)?,
567 options: option_asc,
568 },
569 PhysicalSortExpr {
570 expr: col("c", &schema)?,
571 options: option_asc,
572 },
573 ]),
574 Arc::clone(&source),
575 common_prefix_length,
576 )
577 .with_fetch(Some(4)),
578 ) as Arc<dyn ExecutionPlan>;
579
580 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
581
582 let expected_after_sort = [
583 "+---+---+---+",
584 "| a | b | c |",
585 "+---+---+---+",
586 "| 0 | 1 | 4 |",
587 "| 0 | 2 | 3 |",
588 "| 1 | 2 | 2 |",
589 "| 1 | 3 | 0 |",
590 "+---+---+---+",
591 ];
592 assert_eq!(2, result.len());
593 assert_batches_eq!(expected_after_sort, &result);
594 assert_eq!(
595 task_ctx.runtime_env().memory_pool.reserved(),
596 0,
597 "The sort should have returned all memory used back to the memory manager"
598 );
599 }
600
601 Ok(())
602 }
603
604 #[tokio::test]
605 async fn test_partial_sort2() -> Result<()> {
606 let task_ctx = Arc::new(TaskContext::default());
607 let source_tables = [
608 test::build_table_scan_i32(
609 ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
610 ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
611 ("c", &vec![7, 6, 5, 4, 3, 2, 1, 0]),
612 ),
613 test::build_table_scan_i32(
614 ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
615 ("b", &vec![1, 1, 3, 3, 2, 2, 4, 4]),
616 ("c", &vec![7, 6, 5, 4, 1, 0, 3, 2]),
617 ),
618 ];
619 let schema = Schema::new(vec![
620 Field::new("a", DataType::Int32, false),
621 Field::new("b", DataType::Int32, false),
622 Field::new("c", DataType::Int32, false),
623 ]);
624 let option_asc = SortOptions {
625 descending: false,
626 nulls_first: false,
627 };
628 for (common_prefix_length, source) in
629 [(1, &source_tables[0]), (2, &source_tables[1])]
630 {
631 let partial_sort_exec = Arc::new(PartialSortExec::new(
632 LexOrdering::new(vec![
633 PhysicalSortExpr {
634 expr: col("a", &schema)?,
635 options: option_asc,
636 },
637 PhysicalSortExpr {
638 expr: col("b", &schema)?,
639 options: option_asc,
640 },
641 PhysicalSortExpr {
642 expr: col("c", &schema)?,
643 options: option_asc,
644 },
645 ]),
646 Arc::clone(source),
647 common_prefix_length,
648 ));
649
650 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
651 assert_eq!(2, result.len());
652 assert_eq!(
653 task_ctx.runtime_env().memory_pool.reserved(),
654 0,
655 "The sort should have returned all memory used back to the memory manager"
656 );
657 let expected = [
658 "+---+---+---+",
659 "| a | b | c |",
660 "+---+---+---+",
661 "| 0 | 1 | 6 |",
662 "| 0 | 1 | 7 |",
663 "| 0 | 3 | 4 |",
664 "| 0 | 3 | 5 |",
665 "| 1 | 2 | 0 |",
666 "| 1 | 2 | 1 |",
667 "| 1 | 4 | 2 |",
668 "| 1 | 4 | 3 |",
669 "+---+---+---+",
670 ];
671 assert_batches_eq!(expected, &result);
672 }
673 Ok(())
674 }
675
676 fn prepare_partitioned_input() -> Arc<dyn ExecutionPlan> {
677 let batch1 = test::build_table_i32(
678 ("a", &vec![1; 100]),
679 ("b", &(0..100).rev().collect()),
680 ("c", &(0..100).rev().collect()),
681 );
682 let batch2 = test::build_table_i32(
683 ("a", &[&vec![1; 25][..], &vec![2; 75][..]].concat()),
684 ("b", &(100..200).rev().collect()),
685 ("c", &(0..100).collect()),
686 );
687 let batch3 = test::build_table_i32(
688 ("a", &[&vec![3; 50][..], &vec![4; 50][..]].concat()),
689 ("b", &(150..250).rev().collect()),
690 ("c", &(0..100).rev().collect()),
691 );
692 let batch4 = test::build_table_i32(
693 ("a", &vec![4; 100]),
694 ("b", &(50..150).rev().collect()),
695 ("c", &(0..100).rev().collect()),
696 );
697 let schema = batch1.schema();
698
699 TestMemoryExec::try_new_exec(
700 &[vec![batch1, batch2, batch3, batch4]],
701 Arc::clone(&schema),
702 None,
703 )
704 .unwrap() as Arc<dyn ExecutionPlan>
705 }
706
707 #[tokio::test]
708 async fn test_partitioned_input_partial_sort() -> Result<()> {
709 let task_ctx = Arc::new(TaskContext::default());
710 let mem_exec = prepare_partitioned_input();
711 let option_asc = SortOptions {
712 descending: false,
713 nulls_first: false,
714 };
715 let option_desc = SortOptions {
716 descending: false,
717 nulls_first: false,
718 };
719 let schema = mem_exec.schema();
720 let partial_sort_executor = PartialSortExec::new(
721 LexOrdering::new(vec![
722 PhysicalSortExpr {
723 expr: col("a", &schema)?,
724 options: option_asc,
725 },
726 PhysicalSortExpr {
727 expr: col("b", &schema)?,
728 options: option_desc,
729 },
730 PhysicalSortExpr {
731 expr: col("c", &schema)?,
732 options: option_asc,
733 },
734 ]),
735 Arc::clone(&mem_exec),
736 1,
737 );
738 let partial_sort_exec =
739 Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
740 let sort_exec = Arc::new(SortExec::new(
741 partial_sort_executor.expr,
742 partial_sort_executor.input,
743 )) as Arc<dyn ExecutionPlan>;
744 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
745 assert_eq!(
746 result.iter().map(|r| r.num_rows()).collect_vec(),
747 [125, 125, 150]
748 );
749
750 assert_eq!(
751 task_ctx.runtime_env().memory_pool.reserved(),
752 0,
753 "The sort should have returned all memory used back to the memory manager"
754 );
755 let partial_sort_result = concat_batches(&schema, &result).unwrap();
756 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
757 assert_eq!(sort_result[0], partial_sort_result);
758
759 Ok(())
760 }
761
762 #[tokio::test]
763 async fn test_partitioned_input_partial_sort_with_fetch() -> Result<()> {
764 let task_ctx = Arc::new(TaskContext::default());
765 let mem_exec = prepare_partitioned_input();
766 let schema = mem_exec.schema();
767 let option_asc = SortOptions {
768 descending: false,
769 nulls_first: false,
770 };
771 let option_desc = SortOptions {
772 descending: false,
773 nulls_first: false,
774 };
775 for (fetch_size, expected_batch_num_rows) in [
776 (Some(50), vec![50]),
777 (Some(120), vec![120]),
778 (Some(150), vec![125, 25]),
779 (Some(250), vec![125, 125]),
780 ] {
781 let partial_sort_executor = PartialSortExec::new(
782 LexOrdering::new(vec![
783 PhysicalSortExpr {
784 expr: col("a", &schema)?,
785 options: option_asc,
786 },
787 PhysicalSortExpr {
788 expr: col("b", &schema)?,
789 options: option_desc,
790 },
791 PhysicalSortExpr {
792 expr: col("c", &schema)?,
793 options: option_asc,
794 },
795 ]),
796 Arc::clone(&mem_exec),
797 1,
798 )
799 .with_fetch(fetch_size);
800
801 let partial_sort_exec =
802 Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
803 let sort_exec = Arc::new(
804 SortExec::new(partial_sort_executor.expr, partial_sort_executor.input)
805 .with_fetch(fetch_size),
806 ) as Arc<dyn ExecutionPlan>;
807 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
808 assert_eq!(
809 result.iter().map(|r| r.num_rows()).collect_vec(),
810 expected_batch_num_rows
811 );
812
813 assert_eq!(
814 task_ctx.runtime_env().memory_pool.reserved(),
815 0,
816 "The sort should have returned all memory used back to the memory manager"
817 );
818 let partial_sort_result = concat_batches(&schema, &result)?;
819 let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?;
820 assert_eq!(sort_result[0], partial_sort_result);
821 }
822
823 Ok(())
824 }
825
826 #[tokio::test]
827 async fn test_partial_sort_no_empty_batches() -> Result<()> {
828 let task_ctx = Arc::new(TaskContext::default());
829 let mem_exec = prepare_partitioned_input();
830 let schema = mem_exec.schema();
831 let option_asc = SortOptions {
832 descending: false,
833 nulls_first: false,
834 };
835 let fetch_size = Some(250);
836 let partial_sort_executor = PartialSortExec::new(
837 LexOrdering::new(vec![
838 PhysicalSortExpr {
839 expr: col("a", &schema)?,
840 options: option_asc,
841 },
842 PhysicalSortExpr {
843 expr: col("c", &schema)?,
844 options: option_asc,
845 },
846 ]),
847 Arc::clone(&mem_exec),
848 1,
849 )
850 .with_fetch(fetch_size);
851
852 let partial_sort_exec =
853 Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
854 let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
855 for rb in result {
856 assert!(rb.num_rows() > 0);
857 }
858
859 Ok(())
860 }
861
862 #[tokio::test]
863 async fn test_sort_metadata() -> Result<()> {
864 let task_ctx = Arc::new(TaskContext::default());
865 let field_metadata: HashMap<String, String> =
866 vec![("foo".to_string(), "bar".to_string())]
867 .into_iter()
868 .collect();
869 let schema_metadata: HashMap<String, String> =
870 vec![("baz".to_string(), "barf".to_string())]
871 .into_iter()
872 .collect();
873
874 let mut field = Field::new("field_name", DataType::UInt64, true);
875 field.set_metadata(field_metadata.clone());
876 let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone());
877 let schema = Arc::new(schema);
878
879 let data: ArrayRef =
880 Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::<UInt64Array>());
881
882 let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;
883 let input =
884 TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?;
885
886 let partial_sort_exec = Arc::new(PartialSortExec::new(
887 LexOrdering::new(vec![PhysicalSortExpr {
888 expr: col("field_name", &schema)?,
889 options: SortOptions::default(),
890 }]),
891 input,
892 1,
893 ));
894
895 let result: Vec<RecordBatch> = collect(partial_sort_exec, task_ctx).await?;
896 let expected_batch = vec![
897 RecordBatch::try_new(
898 Arc::clone(&schema),
899 vec![Arc::new(
900 vec![1, 1].into_iter().map(Some).collect::<UInt64Array>(),
901 )],
902 )?,
903 RecordBatch::try_new(
904 Arc::clone(&schema),
905 vec![Arc::new(
906 vec![2].into_iter().map(Some).collect::<UInt64Array>(),
907 )],
908 )?,
909 ];
910
911 assert_eq!(&expected_batch, &result);
913
914 assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata);
916 assert_eq!(result[0].schema().metadata(), &schema_metadata);
917
918 Ok(())
919 }
920
921 #[tokio::test]
922 async fn test_lex_sort_by_float() -> Result<()> {
923 let task_ctx = Arc::new(TaskContext::default());
924 let schema = Arc::new(Schema::new(vec![
925 Field::new("a", DataType::Float32, true),
926 Field::new("b", DataType::Float64, true),
927 Field::new("c", DataType::Float64, true),
928 ]));
929 let option_asc = SortOptions {
930 descending: false,
931 nulls_first: true,
932 };
933 let option_desc = SortOptions {
934 descending: true,
935 nulls_first: true,
936 };
937
938 let batch = RecordBatch::try_new(
940 Arc::clone(&schema),
941 vec![
942 Arc::new(Float32Array::from(vec![
943 Some(1.0_f32),
944 Some(1.0_f32),
945 Some(1.0_f32),
946 Some(2.0_f32),
947 Some(2.0_f32),
948 Some(3.0_f32),
949 Some(3.0_f32),
950 Some(3.0_f32),
951 ])),
952 Arc::new(Float64Array::from(vec![
953 Some(20.0_f64),
954 Some(20.0_f64),
955 Some(40.0_f64),
956 Some(40.0_f64),
957 Some(f64::NAN),
958 None,
959 None,
960 Some(f64::NAN),
961 ])),
962 Arc::new(Float64Array::from(vec![
963 Some(10.0_f64),
964 Some(20.0_f64),
965 Some(10.0_f64),
966 Some(100.0_f64),
967 Some(f64::NAN),
968 Some(100.0_f64),
969 None,
970 Some(f64::NAN),
971 ])),
972 ],
973 )?;
974
975 let partial_sort_exec = Arc::new(PartialSortExec::new(
976 LexOrdering::new(vec![
977 PhysicalSortExpr {
978 expr: col("a", &schema)?,
979 options: option_asc,
980 },
981 PhysicalSortExpr {
982 expr: col("b", &schema)?,
983 options: option_asc,
984 },
985 PhysicalSortExpr {
986 expr: col("c", &schema)?,
987 options: option_desc,
988 },
989 ]),
990 TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?,
991 2,
992 ));
993
994 let expected = [
995 "+-----+------+-------+",
996 "| a | b | c |",
997 "+-----+------+-------+",
998 "| 1.0 | 20.0 | 20.0 |",
999 "| 1.0 | 20.0 | 10.0 |",
1000 "| 1.0 | 40.0 | 10.0 |",
1001 "| 2.0 | 40.0 | 100.0 |",
1002 "| 2.0 | NaN | NaN |",
1003 "| 3.0 | | |",
1004 "| 3.0 | | 100.0 |",
1005 "| 3.0 | NaN | NaN |",
1006 "+-----+------+-------+",
1007 ];
1008
1009 assert_eq!(
1010 DataType::Float32,
1011 *partial_sort_exec.schema().field(0).data_type()
1012 );
1013 assert_eq!(
1014 DataType::Float64,
1015 *partial_sort_exec.schema().field(1).data_type()
1016 );
1017 assert_eq!(
1018 DataType::Float64,
1019 *partial_sort_exec.schema().field(2).data_type()
1020 );
1021
1022 let result: Vec<RecordBatch> = collect(
1023 Arc::clone(&partial_sort_exec) as Arc<dyn ExecutionPlan>,
1024 task_ctx,
1025 )
1026 .await?;
1027 assert_batches_eq!(expected, &result);
1028 assert_eq!(result.len(), 2);
1029 let metrics = partial_sort_exec.metrics().unwrap();
1030 assert!(metrics.elapsed_compute().unwrap() > 0);
1031 assert_eq!(metrics.output_rows().unwrap(), 8);
1032
1033 let columns = result[0].columns();
1034
1035 assert_eq!(DataType::Float32, *columns[0].data_type());
1036 assert_eq!(DataType::Float64, *columns[1].data_type());
1037 assert_eq!(DataType::Float64, *columns[2].data_type());
1038
1039 Ok(())
1040 }
1041
1042 #[tokio::test]
1043 async fn test_drop_cancel() -> Result<()> {
1044 let task_ctx = Arc::new(TaskContext::default());
1045 let schema = Arc::new(Schema::new(vec![
1046 Field::new("a", DataType::Float32, true),
1047 Field::new("b", DataType::Float32, true),
1048 ]));
1049
1050 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
1051 let refs = blocking_exec.refs();
1052 let sort_exec = Arc::new(PartialSortExec::new(
1053 LexOrdering::new(vec![PhysicalSortExpr {
1054 expr: col("a", &schema)?,
1055 options: SortOptions::default(),
1056 }]),
1057 blocking_exec,
1058 1,
1059 ));
1060
1061 let fut = collect(sort_exec, Arc::clone(&task_ctx));
1062 let mut fut = fut.boxed();
1063
1064 assert_is_pending(&mut fut);
1065 drop(fut);
1066 assert_strong_count_converges_to_zero(refs).await;
1067
1068 assert_eq!(
1069 task_ctx.runtime_env().memory_pool.reserved(),
1070 0,
1071 "The sort should have returned all memory used back to the memory manager"
1072 );
1073
1074 Ok(())
1075 }
1076}