use std::any::Any;
use std::sync::Arc;

use crate::common::spawn_buffered;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{make_with_child, update_expr, ProjectionExec};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};

use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};

use log::{debug, trace};

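/// Sort preserving merge execution plan
///
/// Takes an input [`ExecutionPlan`] whose partitions are each sorted on
/// `expr` and merges them into a single output partition that is also sorted
/// on `expr`, streaming record batches as they become available. An optional
/// `fetch` limit stops the merge after that many rows have been produced.
///
/// A sketch of typical construction (`sorted_input` and `ordering` are
/// placeholders for any plan whose partitions are each sorted on `ordering`):
///
/// ```ignore
/// let merge = SortPreservingMergeExec::new(ordering, sorted_input)
///     .with_fetch(Some(10));
/// ```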
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
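    /// Input plan with sorted partitions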
    input: Arc<dyn ExecutionPlan>,
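    /// Sort expressions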
    expr: LexOrdering,
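    /// Execution metrics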
    metrics: ExecutionPlanMetricsSet,
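    /// Optional number of rows to fetch; the merge stops after producing
    /// this many rows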
    fetch: Option<usize>,
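    /// Cached plan properties (equivalences, output partitioning, ordering)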
    cache: PlanProperties,
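    /// Whether to break ties between rows that compare equal by cycling
    /// through the tied input streams (see
    /// [`Self::with_round_robin_repartition`])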
    enable_round_robin_repartition: bool,
}

impl SortPreservingMergeExec {
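    /// Create a new sort-preserving merge execution plan. Each partition of
    /// `input` is expected to be sorted on `expr`.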
    pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input, expr.clone());
        Self {
            input,
            expr,
            metrics: ExecutionPlanMetricsSet::new(),
            fetch: None,
            cache,
            enable_round_robin_repartition: true,
        }
    }

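    /// Sets the number of rows to fetch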
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

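    /// Sets how ties (rows that compare equal on the sort expressions) are
    /// broken across input streams.
    ///
    /// When `true` (the default), tied streams are selected in round-robin
    /// fashion, draining the inputs at a more even rate. When `false`, ties
    /// consistently favor the same stream, which can force the other inputs
    /// to buffer and exceed a memory limit on heavily tied data (see the
    /// round-robin tie breaker tests below).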
    pub fn with_round_robin_repartition(
        mut self,
        enable_round_robin_repartition: bool,
    ) -> Self {
        self.enable_round_robin_repartition = enable_round_robin_repartition;
        self
    }

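    /// Input plan whose sorted partitions are merged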
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

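    /// Sort expressions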
    pub fn expr(&self) -> &LexOrdering {
        self.expr.as_ref()
    }

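    /// Fetch count, if any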
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

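    /// Computes the cached plan properties: the output is a single partition
    /// sorted on `ordering`. Constants that only held within individual
    /// input partitions are cleared, since they need not hold across the
    /// merged output.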
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        ordering: LexOrdering,
    ) -> PlanProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_per_partition_constants();
        eq_properties.add_new_orderings(vec![ordering]);
        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            input.boundedness(),
        )
    }
}

impl DisplayAs for SortPreservingMergeExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
                if let Some(fetch) = self.fetch {
                    write!(f, ", fetch={fetch}")?;
                }

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for SortPreservingMergeExec {
    fn name(&self) -> &'static str {
        "SortPreservingMergeExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            expr: self.expr.clone(),
            metrics: self.metrics.clone(),
            fetch: limit,
            cache: self.cache.clone(),
            enable_round_robin_repartition: true,
        }))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::UnspecifiedDistribution]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![Some(LexRequirement::from(self.expr.clone()))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
                .with_fetch(self.fetch),
        ))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start SortPreservingMergeExec::execute for partition: {}",
            partition
        );
        if partition != 0 {
            return internal_err!(
                "SortPreservingMergeExec invalid partition {partition}"
            );
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        trace!(
            "Number of input partitions of SortPreservingMergeExec::execute: {}",
            input_partitions
        );
        let schema = self.schema();

        let reservation =
            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
                .register(&context.runtime_env().memory_pool);

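        // Dispatch on the number of input partitions: zero is an error, a
        // single sorted partition can be passed through as-is (applying the
        // fetch limit, if any), and only multiple partitions need an actual
        // streaming merge.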
        match input_partitions {
            0 => internal_err!(
                "SortPreservingMergeExec requires at least one input partition"
            ),
            1 => match self.fetch {
                Some(fetch) => {
                    let stream = self.input.execute(0, context)?;
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}");
                    Ok(Box::pin(LimitStream::new(
                        stream,
                        0,
                        Some(fetch),
                        BaselineMetrics::new(&self.metrics, partition),
                    )))
                }
                None => {
                    let stream = self.input.execute(0, context);
                    debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch");
                    stream
                }
            },
            _ => {
                let receivers = (0..input_partitions)
                    .map(|partition| {
                        let stream =
                            self.input.execute(partition, Arc::clone(&context))?;
                        Ok(spawn_buffered(stream, 1))
                    })
                    .collect::<Result<_>>()?;

                debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

                let result = StreamingMergeBuilder::new()
                    .with_streams(receivers)
                    .with_schema(schema)
                    .with_expressions(self.expr.as_ref())
                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(self.fetch)
                    .with_reservation(reservation)
                    .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
                    .build()?;

                debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

                Ok(result)
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

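    /// Tries to swap this merge with the given `projection`, pushing the
    /// projection below the merge so that less data flows into it. Returns
    /// `Ok(None)` when the swap is not possible or not beneficial: when the
    /// projection does not narrow the schema, or when a sort expression
    /// cannot be rewritten in terms of the projected columns.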
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        let mut updated_exprs = LexOrdering::default();
        for sort in self.expr() {
            let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)?
            else {
                return Ok(None);
            };
            updated_exprs.push(PhysicalSortExpr {
                expr: updated_expr,
                options: sort.options,
            });
        }

        Ok(Some(Arc::new(
            SortPreservingMergeExec::new(
                updated_exprs,
                make_with_child(projection, self.input())?,
            )
            .with_fetch(self.fetch()),
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Formatter;
    use std::pin::Pin;
    use std::sync::Mutex;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::execution_plan::{Boundedness, EmissionType};
    use crate::expressions::col;
    use crate::metrics::{MetricValue, Timestamp};
    use crate::repartition::RepartitionExec;
    use crate::sorts::sort::SortExec;
    use crate::stream::RecordBatchReceiverStream;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::test::{self, assert_is_pending, make_partition};
    use crate::{collect, common};

    use arrow::array::{
        ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
        TimestampNanosecondArray,
    };
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_execution::RecordBatchStream;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::EquivalenceProperties;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
    use futures::{FutureExt, Stream, StreamExt};
    use tokio::time::timeout;

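    /// Task context with a 20 MB memory limit, used by the round-robin tie
    /// breaker tests to observe whether the merge stays within budget.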
    fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(20_000_000, 1.0)
            .build_arc()?;
        let config = SessionConfig::new();
        let task_ctx = TaskContext::default()
            .with_runtime(runtime)
            .with_session_config(config);
        Ok(Arc::new(task_ctx))
    }
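
    /// Builds a `SortPreservingMergeExec` over two round-robin repartitioned
    /// streams of identical rows, so every row ties on the sort key, with the
    /// round-robin tie breaker enabled or disabled.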
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let target_batch_size = 12500;
        let row_size = 12500;
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let rbs = (0..1024).map(|_| rb.clone()).collect::<Vec<_>>();

        let schema = rb.schema();
        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]);

        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None).unwrap(),
            Partitioning::RoundRobinBatch(2),
        )?;
        let coalesce_batches_exec =
            CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size);
        let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }

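    /// With the round-robin tie breaker enabled, tied inputs are drained
    /// evenly and the merge completes within the memory limit.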
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_success() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(true)?;
        let _collected = collect(spm, task_ctx).await.unwrap();
        Ok(())
    }

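    /// With the round-robin tie breaker disabled, ties consistently favor
    /// one input, the others buffer, and the merge exceeds the memory limit,
    /// so collection must fail.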
    #[tokio::test(flavor = "multi_thread")]
    async fn test_round_robin_tie_breaker_fail() -> Result<()> {
        let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?;
        let spm = generate_spm_for_round_robin_tie_breaker(false)?;
        let _err = collect(spm, task_ctx).await.unwrap_err();
        Ok(())
    }

    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_exprs() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap();

        let schema = batch.schema();
        let sort = LexOrdering::default();
        let exec = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()], vec![batch]],
            schema,
            None,
        )
        .unwrap();

        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let res = collect(merge, task_ctx).await.unwrap_err();
        assert_contains!(
            res.to_string(),
            "Internal error: Sort expressions cannot be empty for streaming merge"
        );
    }

    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

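    /// Merges `partitions` sorted on columns `b` and `c` and asserts that
    /// the merged output matches `exp`.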
    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ]);
        let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }

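    /// Merges the sorted partitions of `input` into a single sorted batch.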
    async fn sorted_merge(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let mut result = collect(merge, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

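    /// Sorts each partition of `input` independently, then combines them
    /// with a `SortPreservingMergeExec`.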
    async fn partition_sort(
        input: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let sort_exec =
            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
        sorted_merge(sort_exec, sort, context).await
    }

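    /// Reference implementation: coalesces all partitions into one, then
    /// performs a full sort.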
    async fn basic_sort(
        src: Arc<dyn ExecutionPlan>,
        sort: LexOrdering,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(CoalescePartitionsExec::new(src));
        let sort_exec = Arc::new(SortExec::new(sort, merge));
        let mut result = collect(sort_exec, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }]);

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

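    /// Splits `sorted` into batches of at most `batch_size` rows.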
    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        let batches = sorted.num_rows().div_ceil(batch_size);

        (0..batches)
            .map(|batch_idx| {
                let columns = (0..sorted.num_columns())
                    .map(|column_idx| {
                        let length =
                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);

                        sorted
                            .column(column_idx)
                            .slice(batch_idx * batch_size, length)
                    })
                    .collect();

                RecordBatch::try_new(sorted.schema(), columns).unwrap()
            })
            .collect()
    }

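    /// Returns a memory exec with one partition per entry in `sizes`, each
    /// containing the same globally sorted data split into batches of the
    /// corresponding size.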
    async fn sorted_partitioned_input(
        sort: LexOrdering,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        Ok(TestMemoryExec::try_new_exec(&split, sorted.schema(), None).unwrap())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }]);

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }]);

        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await.unwrap();

        assert_eq!(merged.len(), 53);

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice())
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]);
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+-------------------------------+",
                "| a | b | c                             |",
                "+---+---+-------------------------------+",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 2 | a |                               |",
                "| 7 | b | 1970-01-01T00:00:00.000000006 |",
                "| 2 | b |                               |",
                "| 9 | d |                               |",
                "| 3 | e | 1970-01-01T00:00:00.000000004 |",
                "| 3 | g | 1970-01-01T00:00:00.000000005 |",
                "| 4 | h |                               |",
                "| 5 | i | 1970-01-01T00:00:00.000000004 |",
                "+---+---+-------------------------------+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge =
            Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "| 7 | c |",
                "| 9 | d |",
                "| 3 | e |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]);

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(sort.as_ref())
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }]);
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        let expected = [
            "+----+---+",
            "| a  | b |",
            "+----+---+",
            "| 1  | a |",
            "| 10 | b |",
            "| 2  | c |",
            "| 20 | d |",
            "+----+---+",
        ];
        assert_batches_eq!(expected, collected.as_slice());

        let metrics = merge.metrics().unwrap();

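        // The merge should report its output row count, a non-zero compute
        // time, and start/end timestamps.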
        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }

    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            LexOrdering::new(vec![PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }]),
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

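        // Ten single-batch partitions whose "value" column is identical
        // ("A" then "B") and whose "batch_number" records the originating
        // partition. Every row ties on the sort key, so a stable merge must
        // emit ties in input (partition) order.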
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        let sort = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]);

        let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

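        // The merge is stable: within each value, rows appear in partition
        // order.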
        assert_batches_eq!(
            &[
                "+--------------+-------+",
                "| batch_number | value |",
                "+--------------+-------+",
                "| 0            | A     |",
                "| 1            | A     |",
                "| 2            | A     |",
                "| 3            | A     |",
                "| 4            | A     |",
                "| 5            | A     |",
                "| 6            | A     |",
                "| 7            | A     |",
                "| 8            | A     |",
                "| 9            | A     |",
                "| 0            | B     |",
                "| 1            | B     |",
                "| 2            | B     |",
                "| 3            | B     |",
                "| 4            | B     |",
                "| 5            | B     |",
                "| 6            | B     |",
                "| 7            | B     |",
                "| 8            | B     |",
                "| 9            | B     |",
                "+--------------+-------+",
            ],
            collected.as_slice()
        );
    }

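    /// Test plan with three partitions: partition 0 is exhausted immediately,
    /// partition 1 stays `Pending` until partition 2 clears the shared
    /// "congestion" flag, and partition 2 clears that flag when polled. A
    /// merge that stops polling its other inputs would therefore deadlock.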
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion_cleared: Arc<Mutex<bool>>,
    }

    impl CongestedExec {
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_new_orderings(vec![columns
                .iter()
                .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr)))
                .collect::<LexOrdering>()]);
            PlanProperties::new(
                eq_properties,
                Partitioning::Hash(columns, 3),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            )
        }
    }

    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion_cleared: Arc::clone(&self.congestion_cleared),
                partition,
            }))
        }
    }

    impl DisplayAs for CongestedExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CongestedExec")
                }
            }
        }
    }

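    /// Stream backing [`CongestedExec`]; panics if its exhausted partition is
    /// polled again.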
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        none_polled_once: bool,
        congestion_cleared: Arc<Mutex<bool>>,
        partition: usize,
    }

    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                1 => {
                    let cleared = self.congestion_cleared.lock().unwrap();
                    if *cleared {
                        Poll::Ready(None)
                    } else {
                        Poll::Pending
                    }
                }
                2 => {
                    let mut cleared = self.congestion_cleared.lock().unwrap();
                    *cleared = true;
                    Poll::Ready(None)
                }
                _ => unreachable!(),
            }
        }
    }

    impl RecordBatchStream for CongestedStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }

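    /// The merge must keep polling all of its inputs: partition 1 completes
    /// only after partition 2 has been polled, so an unfair merge would hang
    /// forever. The timeout converts such a deadlock into a test failure.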
    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let source = CongestedExec {
            schema: schema.clone(),
            cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
            congestion_cleared: Arc::new(Mutex::new(false)),
        };
        let spm = SortPreservingMergeExec::new(
            LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))]),
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => Err(DataFusionError::Execution(
                "SortPreservingMerge task panicked or was cancelled".to_string(),
            )),
            Err(_) => Err(DataFusionError::Execution(
                "SortPreservingMerge caused a deadlock".to_string(),
            )),
        }
    }
}