//! Defines the LIMIT plan

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::{Boundedness, CardinalityEffect};
use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

use futures::stream::{Stream, StreamExt};
use log::trace;

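/// Limit execution plan: skips the first `skip` rows from its single input
/// partition, then returns at most `fetch` rows.
///
/// A minimal construction sketch (the `child` plan is a stand-in for any
/// single-partition `ExecutionPlan`); marked `ignore` since it needs a real
/// child plan to run:
///
/// ```ignore
/// // Skip the first 10 rows, then return at most 5 of the remaining rows.
/// let limit = GlobalLimitExec::new(child, 10, Some(5));
/// assert_eq!(limit.skip(), 10);
/// assert_eq!(limit.fetch(), Some(5));
/// ```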
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of rows to skip before fetching
    skip: usize,
    /// Maximum number of rows to fetch; `None` means fetch all rows
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (partitioning, equivalences, etc.)
    cache: PlanProperties,
}

impl GlobalLimitExec {
    /// Create a new `GlobalLimitExec`
    pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
        let cache = Self::compute_properties(&input);
        GlobalLimitExec {
            input,
            skip,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Number of rows to skip before fetching
    pub fn skip(&self) -> usize {
        self.skip
    }

    /// Maximum number of rows to fetch (`None` means all rows)
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Compute the cached plan properties; the output is always a single partition
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for GlobalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "GlobalLimitExec: skip={}, fetch={}",
                    self.skip,
                    self.fetch.map_or("None".to_string(), |x| x.to_string())
                )
            }
        }
    }
}

impl ExecutionPlan for GlobalLimitExec {
    fn name(&self) -> &'static str {
        "GlobalLimitExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(GlobalLimitExec::new(
            Arc::clone(&children[0]),
            self.skip,
            self.fetch,
        )))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start GlobalLimitExec::execute for partition: {}",
            partition
        );
        // GlobalLimitExec has a single output partition
        if 0 != partition {
            return internal_err!("GlobalLimitExec invalid partition {partition}");
        }

        // GlobalLimitExec requires a single input partition
        if 1 != self.input.output_partitioning().partition_count() {
            return internal_err!("GlobalLimitExec requires a single input partition");
        }

        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(0, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            self.skip,
            self.fetch,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Statistics::with_fetch(
            self.input.statistics()?,
            self.schema(),
            self.fetch,
            self.skip,
            1,
        )
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

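/// LocalLimitExec applies a limit to each partition of its input
/// independently; it is typically paired with a `GlobalLimitExec` that
/// enforces the final limit after partitions are merged.
///
/// A minimal construction sketch (`child` is a stand-in for any input
/// `ExecutionPlan`); marked `ignore` since it needs a real child plan:
///
/// ```ignore
/// // Return at most 5 rows from *each* input partition.
/// let limit = LocalLimitExec::new(child, 5);
/// assert_eq!(limit.fetch(), 5);
/// ```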
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return per partition
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties
    cache: PlanProperties,
}

impl LocalLimitExec {
    /// Create a new `LocalLimitExec` that fetches at most `fetch` rows per partition
    pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
        let cache = Self::compute_properties(&input);
        Self {
            input,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Maximum number of rows to fetch from each partition
    pub fn fetch(&self) -> usize {
        self.fetch
    }

    /// Compute the cached plan properties; partitioning is inherited from the input
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            input.output_partitioning().clone(),
            input.pipeline_behavior(),
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for LocalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "LocalLimitExec: fetch={}", self.fetch)
            }
        }
    }
}

impl ExecutionPlan for LocalLimitExec {
    fn name(&self) -> &'static str {
        "LocalLimitExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(LocalLimitExec::new(
                Arc::clone(&children[0]),
                self.fetch,
            ))),
            _ => internal_err!("LocalLimitExec wrong number of children"),
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(partition, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            0,
            Some(self.fetch),
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        Statistics::with_fetch(
            self.input.statistics()?,
            self.schema(),
            Some(self.fetch),
            0,
            1,
        )
    }

    fn fetch(&self) -> Option<usize> {
        Some(self.fetch)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
}

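/// A `LimitStream` wraps a `SendableRecordBatchStream`, skips `skip` rows,
/// then yields at most `fetch` rows, dropping the input stream as soon as
/// the limit is reached so upstream operators can stop early.
///
/// A minimal sketch, assuming `input` is any `SendableRecordBatchStream`
/// already in scope; marked `ignore` since it needs a live input stream:
///
/// ```ignore
/// let metrics = ExecutionPlanMetricsSet::new();
/// let baseline = BaselineMetrics::new(&metrics, /* partition */ 0);
/// // Skip 3 rows, then yield at most 7 more.
/// let stream = LimitStream::new(input, 3, Some(7), baseline);
/// ```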
pub struct LimitStream {
    /// The remaining number of rows to skip
    skip: usize,
    /// The remaining number of rows to produce
    fetch: usize,
    /// The input to read from. Set to `None` once the limit is reached so
    /// the input stream can be dropped early
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema
    schema: SchemaRef,
    /// Execution time metrics
    baseline_metrics: BaselineMetrics,
}

impl LimitStream {
    pub fn new(
        input: SendableRecordBatchStream,
        skip: usize,
        fetch: Option<usize>,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            skip,
            fetch: fetch.unwrap_or(usize::MAX),
            input: Some(input),
            schema,
            baseline_metrics,
        }
    }

    /// Poll the input until `skip` rows have been discarded, returning the
    /// first non-empty batch that starts past the offset
    fn poll_and_skip(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let input = self.input.as_mut().unwrap();
        loop {
            let poll = input.poll_next_unpin(cx);
            let poll = poll.map_ok(|batch| {
                if batch.num_rows() <= self.skip {
                    // skip the entire batch
                    self.skip -= batch.num_rows();
                    RecordBatch::new_empty(input.schema())
                } else {
                    // the batch straddles the offset: keep only the tail
                    let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
                    self.skip = 0;
                    new_batch
                }
            });

            match &poll {
                Poll::Ready(Some(Ok(batch))) => {
                    if batch.num_rows() > 0 {
                        break poll;
                    } else {
                        // empty batch, continue to poll the input stream
                    }
                }
                Poll::Ready(Some(Err(_e))) => break poll,
                Poll::Ready(None) => break poll,
                Poll::Pending => break poll,
            }
        }
    }

    /// Limit `batch` to the remaining `fetch` rows, updating the counter
    /// and releasing the input once the limit is reached
    fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
        // records time on drop
        let _timer = self.baseline_metrics.elapsed_compute().timer();
        if self.fetch == 0 {
            self.input = None; // clear input so it can be dropped early
            None
        } else if batch.num_rows() < self.fetch {
            self.fetch -= batch.num_rows();
            Some(batch)
        } else if batch.num_rows() >= self.fetch {
            let batch_rows = self.fetch;
            self.fetch = 0;
            self.input = None; // clear input so it can be dropped early

            // batch_rows is guaranteed to be <= batch.num_rows()
            Some(batch.slice(0, batch_rows))
        } else {
            unreachable!()
        }
    }
}

impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    self.poll_and_skip(cx)
                };

                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // the input was cleared: the limit has already been reached
            None => Poll::Ready(None),
        };

        self.baseline_metrics.record_poll(poll)
    }
}

impl RecordBatchStream for LimitStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr::PhysicalExpr;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // the input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // the result should contain only the first 7 rows
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of 6 needs the entire first batch (5 rows) and one row
        // from the second batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first two batches should have been consumed
        assert_eq!(index.value(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of 6 needs to consume exactly the first batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of 6 needs to consume exactly the first batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // test cases for "skip"
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        // 4 partitions of 100 rows each, 400 rows in total
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // there is a total of 400 rows and we skip 3 (offset = 3)
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // there is a total of 400 rows, so skipping all of them leaves nothing to fetch
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // there is a total of 400 rows and we skip 401 (offset = 401)
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.statistics()?.num_rows)
    }

    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr.clone())
    }

    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // grouping the input makes its output row count unknown, so the
        // limit statistics become inexact
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.statistics()?.num_rows)
    }

    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.statistics()?.num_rows)
    }

    /// Return a `RecordBatch` with `sz` rows but no columns
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Option::from(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}