datafusion_physical_plan/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream wrappers for physical operators
19
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::Context;
23use std::task::Poll;
24
25use super::metrics::BaselineMetrics;
26use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
27use crate::displayable;
28
29use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
30use datafusion_common::{internal_err, Result};
31use datafusion_execution::TaskContext;
32
33use futures::stream::BoxStream;
34use futures::{Future, Stream, StreamExt};
35use log::debug;
36use pin_project_lite::pin_project;
37use tokio::sync::mpsc::{Receiver, Sender};
38use tokio::task::JoinSet;
39
40/// Creates a stream from a collection of producing tasks, routing panics to the stream.
41///
42/// Note that this is similar to  [`ReceiverStream` from tokio-stream], with the differences being:
43///
44/// 1. Methods to bound and "detach"  tasks (`spawn()` and `spawn_blocking()`).
45///
46/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
47///
48/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
49///
50/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
51pub(crate) struct ReceiverStreamBuilder<O> {
52    tx: Sender<Result<O>>,
53    rx: Receiver<Result<O>>,
54    join_set: JoinSet<Result<()>>,
55}
56
57impl<O: Send + 'static> ReceiverStreamBuilder<O> {
58    /// Create new channels with the specified buffer size
59    pub fn new(capacity: usize) -> Self {
60        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
61
62        Self {
63            tx,
64            rx,
65            join_set: JoinSet::new(),
66        }
67    }
68
69    /// Get a handle for sending data to the output
70    pub fn tx(&self) -> Sender<Result<O>> {
71        self.tx.clone()
72    }
73
74    /// Spawn task that will be aborted if this builder (or the stream
75    /// built from it) are dropped
76    pub fn spawn<F>(&mut self, task: F)
77    where
78        F: Future<Output = Result<()>>,
79        F: Send + 'static,
80    {
81        self.join_set.spawn(task);
82    }
83
84    /// Spawn a blocking task that will be aborted if this builder (or the stream
85    /// built from it) are dropped.
86    ///
87    /// This is often used to spawn tasks that write to the sender
88    /// retrieved from `Self::tx`.
89    pub fn spawn_blocking<F>(&mut self, f: F)
90    where
91        F: FnOnce() -> Result<()>,
92        F: Send + 'static,
93    {
94        self.join_set.spawn_blocking(f);
95    }
96
97    /// Create a stream of all data written to `tx`
98    pub fn build(self) -> BoxStream<'static, Result<O>> {
99        let Self {
100            tx,
101            rx,
102            mut join_set,
103        } = self;
104
105        // Doesn't need tx
106        drop(tx);
107
108        // future that checks the result of the join set, and propagates panic if seen
109        let check = async move {
110            while let Some(result) = join_set.join_next().await {
111                match result {
112                    Ok(task_result) => {
113                        match task_result {
114                            // Nothing to report
115                            Ok(_) => continue,
116                            // This means a blocking task error
117                            Err(error) => return Some(Err(error)),
118                        }
119                    }
120                    // This means a tokio task error, likely a panic
121                    Err(e) => {
122                        if e.is_panic() {
123                            // resume on the main thread
124                            std::panic::resume_unwind(e.into_panic());
125                        } else {
126                            // This should only occur if the task is
127                            // cancelled, which would only occur if
128                            // the JoinSet were aborted, which in turn
129                            // would imply that the receiver has been
130                            // dropped and this code is not running
131                            return Some(internal_err!("Non Panic Task error: {e}"));
132                        }
133                    }
134                }
135            }
136            None
137        };
138
139        let check_stream = futures::stream::once(check)
140            // unwrap Option / only return the error
141            .filter_map(|item| async move { item });
142
143        // Convert the receiver into a stream
144        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
145            let next_item = rx.recv().await;
146            next_item.map(|next_item| (next_item, rx))
147        });
148
149        // Merge the streams together so whichever is ready first
150        // produces the batch
151        futures::stream::select(rx_stream, check_stream).boxed()
152    }
153}
154
155/// Builder for `RecordBatchReceiverStream` that propagates errors
156/// and panic's correctly.
157///
158/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
159/// that produce [`RecordBatch`]es and send them to a single
160/// `Receiver` which can improve parallelism.
161///
162/// This also handles propagating panic`s and canceling the tasks.
163///
164/// # Example
165///
166/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
167/// the `tx` end of the builder, after building the stream, we can receive
168/// those batches with calling `.next()`
169///
170/// ```
171/// # use std::sync::Arc;
172/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
173/// # use datafusion_common::arrow::array::RecordBatch;
174/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
175/// # use futures::stream::StreamExt;
176/// # use tokio::runtime::Builder;
177/// # let rt = Builder::new_current_thread().build().unwrap();
178/// #
179/// # rt.block_on(async {
180/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
181/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
182///
183/// // task 1
184/// let tx_1 = builder.tx();
185/// let schema_1 = Arc::clone(&schema);
186/// builder.spawn(async move {
187///     // Your task needs to send batches to the tx
188///     tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap();
189///
190///     Ok(())
191/// });
192///
193/// // task 2
194/// let tx_2 = builder.tx();
195/// let schema_2 = Arc::clone(&schema);
196/// builder.spawn(async move {
197///     // Your task needs to send batches to the tx
198///     tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap();
199///
200///     Ok(())
201/// });
202///
203/// let mut stream = builder.build();
204/// while let Some(res_batch) = stream.next().await {
205///     // `res_batch` can either from task 1 or 2
206///
207///     // do something with `res_batch`
208/// }
209/// # });
210/// ```
211pub struct RecordBatchReceiverStreamBuilder {
212    schema: SchemaRef,
213    inner: ReceiverStreamBuilder<RecordBatch>,
214}
215
216impl RecordBatchReceiverStreamBuilder {
217    /// Create new channels with the specified buffer size
218    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
219        Self {
220            schema,
221            inner: ReceiverStreamBuilder::new(capacity),
222        }
223    }
224
225    /// Get a handle for sending [`RecordBatch`] to the output
226    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
227        self.inner.tx()
228    }
229
230    /// Spawn task that will be aborted if this builder (or the stream
231    /// built from it) are dropped
232    ///
233    /// This is often used to spawn tasks that write to the sender
234    /// retrieved from [`Self::tx`], for examples, see the document
235    /// of this type.
236    pub fn spawn<F>(&mut self, task: F)
237    where
238        F: Future<Output = Result<()>>,
239        F: Send + 'static,
240    {
241        self.inner.spawn(task)
242    }
243
244    /// Spawn a blocking task that will be aborted if this builder (or the stream
245    /// built from it) are dropped
246    ///
247    /// This is often used to spawn tasks that write to the sender
248    /// retrieved from [`Self::tx`], for examples, see the document
249    /// of this type.
250    pub fn spawn_blocking<F>(&mut self, f: F)
251    where
252        F: FnOnce() -> Result<()>,
253        F: Send + 'static,
254    {
255        self.inner.spawn_blocking(f)
256    }
257
258    /// Runs the `partition` of the `input` ExecutionPlan on the
259    /// tokio thread pool and writes its outputs to this stream
260    ///
261    /// If the input partition produces an error, the error will be
262    /// sent to the output stream and no further results are sent.
263    pub(crate) fn run_input(
264        &mut self,
265        input: Arc<dyn ExecutionPlan>,
266        partition: usize,
267        context: Arc<TaskContext>,
268    ) {
269        let output = self.tx();
270
271        self.inner.spawn(async move {
272            let mut stream = match input.execute(partition, context) {
273                Err(e) => {
274                    // If send fails, the plan being torn down, there
275                    // is no place to send the error and no reason to continue.
276                    output.send(Err(e)).await.ok();
277                    debug!(
278                        "Stopping execution: error executing input: {}",
279                        displayable(input.as_ref()).one_line()
280                    );
281                    return Ok(());
282                }
283                Ok(stream) => stream,
284            };
285
286            // Transfer batches from inner stream to the output tx
287            // immediately.
288            while let Some(item) = stream.next().await {
289                let is_err = item.is_err();
290
291                // If send fails, plan being torn down, there is no
292                // place to send the error and no reason to continue.
293                if output.send(item).await.is_err() {
294                    debug!(
295                        "Stopping execution: output is gone, plan cancelling: {}",
296                        displayable(input.as_ref()).one_line()
297                    );
298                    return Ok(());
299                }
300
301                // Stop after the first error is encountered (Don't
302                // drive all streams to completion)
303                if is_err {
304                    debug!(
305                        "Stopping execution: plan returned error: {}",
306                        displayable(input.as_ref()).one_line()
307                    );
308                    return Ok(());
309                }
310            }
311
312            Ok(())
313        });
314    }
315
316    /// Create a stream of all [`RecordBatch`] written to `tx`
317    pub fn build(self) -> SendableRecordBatchStream {
318        Box::pin(RecordBatchStreamAdapter::new(
319            self.schema,
320            self.inner.build(),
321        ))
322    }
323}
324
325#[doc(hidden)]
326pub struct RecordBatchReceiverStream {}
327
328impl RecordBatchReceiverStream {
329    /// Create a builder with an internal buffer of capacity batches.
330    pub fn builder(
331        schema: SchemaRef,
332        capacity: usize,
333    ) -> RecordBatchReceiverStreamBuilder {
334        RecordBatchReceiverStreamBuilder::new(schema, capacity)
335    }
336}
337
pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`], implementing
    /// [`RecordBatchStream`] for the combination. Pin it (for example with
    /// `Box::pin`) to obtain a [`SendableRecordBatchStream`].
    ///
    /// See [`Self::new`] for an example
    pub struct RecordBatchStreamAdapter<S> {
        // Schema returned by `RecordBatchStream::schema`
        schema: SchemaRef,

        // The wrapped stream; `#[pin]` makes `project()` yield a pinned
        // reference so `poll_next` can be forwarded to it.
        #[pin]
        stream: S,
    }
}
350
351impl<S> RecordBatchStreamAdapter<S> {
352    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
353    ///
354    /// Note to create a [`SendableRecordBatchStream`] you pin the result
355    ///
356    /// # Example
357    /// ```
358    /// # use arrow::array::record_batch;
359    /// # use datafusion_execution::SendableRecordBatchStream;
360    /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
361    /// // Create stream of Result<RecordBatch>
362    /// let batch = record_batch!(
363    ///   ("a", Int32, [1, 2, 3]),
364    ///   ("b", Float64, [Some(4.0), None, Some(5.0)])
365    /// ).expect("created batch");
366    /// let schema = batch.schema();
367    /// let stream = futures::stream::iter(vec![Ok(batch)]);
368    /// // Convert the stream to a SendableRecordBatchStream
369    /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
370    /// // Now you can use the adapter as a SendableRecordBatchStream
371    /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
372    /// // ...
373    /// ```
374    pub fn new(schema: SchemaRef, stream: S) -> Self {
375        Self { schema, stream }
376    }
377}
378
379impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        f.debug_struct("RecordBatchStreamAdapter")
382            .field("schema", &self.schema)
383            .finish()
384    }
385}
386
387impl<S> Stream for RecordBatchStreamAdapter<S>
388where
389    S: Stream<Item = Result<RecordBatch>>,
390{
391    type Item = Result<RecordBatch>;
392
393    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
394        self.project().stream.poll_next(cx)
395    }
396
397    fn size_hint(&self) -> (usize, Option<usize>) {
398        self.stream.size_hint()
399    }
400}
401
402impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
403where
404    S: Stream<Item = Result<RecordBatch>>,
405{
406    fn schema(&self) -> SchemaRef {
407        Arc::clone(&self.schema)
408    }
409}
410
411/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
412/// that will produce no results
413pub struct EmptyRecordBatchStream {
414    /// Schema wrapped by Arc
415    schema: SchemaRef,
416}
417
418impl EmptyRecordBatchStream {
419    /// Create an empty RecordBatchStream
420    pub fn new(schema: SchemaRef) -> Self {
421        Self { schema }
422    }
423}
424
425impl RecordBatchStream for EmptyRecordBatchStream {
426    fn schema(&self) -> SchemaRef {
427        Arc::clone(&self.schema)
428    }
429}
430
431impl Stream for EmptyRecordBatchStream {
432    type Item = Result<RecordBatch>;
433
434    fn poll_next(
435        self: Pin<&mut Self>,
436        _cx: &mut Context<'_>,
437    ) -> Poll<Option<Self::Item>> {
438        Poll::Ready(None)
439    }
440}
441
442/// Stream wrapper that records `BaselineMetrics` for a particular
443/// `[SendableRecordBatchStream]` (likely a partition)
444pub(crate) struct ObservedStream {
445    inner: SendableRecordBatchStream,
446    baseline_metrics: BaselineMetrics,
447    fetch: Option<usize>,
448    produced: usize,
449}
450
451impl ObservedStream {
452    pub fn new(
453        inner: SendableRecordBatchStream,
454        baseline_metrics: BaselineMetrics,
455        fetch: Option<usize>,
456    ) -> Self {
457        Self {
458            inner,
459            baseline_metrics,
460            fetch,
461            produced: 0,
462        }
463    }
464
465    fn limit_reached(
466        &mut self,
467        poll: Poll<Option<Result<RecordBatch>>>,
468    ) -> Poll<Option<Result<RecordBatch>>> {
469        let Some(fetch) = self.fetch else { return poll };
470
471        if self.produced >= fetch {
472            return Poll::Ready(None);
473        }
474
475        if let Poll::Ready(Some(Ok(batch))) = &poll {
476            if self.produced + batch.num_rows() > fetch {
477                let batch = batch.slice(0, fetch.saturating_sub(self.produced));
478                self.produced += batch.num_rows();
479                return Poll::Ready(Some(Ok(batch)));
480            };
481            self.produced += batch.num_rows()
482        }
483        poll
484    }
485}
486
487impl RecordBatchStream for ObservedStream {
488    fn schema(&self) -> SchemaRef {
489        self.inner.schema()
490    }
491}
492
493impl Stream for ObservedStream {
494    type Item = Result<RecordBatch>;
495
496    fn poll_next(
497        mut self: Pin<&mut Self>,
498        cx: &mut Context<'_>,
499    ) -> Poll<Option<Self::Item>> {
500        let mut poll = self.inner.poll_next_unpin(cx);
501        if self.fetch.is_some() {
502            poll = self.limit_reached(poll);
503        }
504        self.baseline_metrics.record_poll(poll)
505    }
506}
507
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    /// Schema shared by the tests: a single nullable `Float32` column `a`.
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
    }

    /// A panic in an input partition must be re-raised on the task that
    /// consumes the `RecordBatchReceiverStream`.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let schema = schema();

        let num_partitions = 10;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions);
        consume(input, 10).await
    }

    /// The first panic (from partition 1) must be surfaced promptly, without
    /// first draining the other partition.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        let schema = schema();

        // Make 2 partitions, second partition panics before the first
        let num_partitions = 2;
        let input = PanicExec::new(Arc::clone(&schema), num_partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3); // partition 1 should panic first (after 3)

        // Ensure that the panic results in an early shutdown (that
        // everything stops after the first panic).

        // Since the stream reads every other batch: (0,1,0,1,0,panic)
        // so should not exceed 5 batches prior to the panic
        let max_batches = 5;
        consume(input, max_batches).await
    }

    /// Dropping the built stream must cancel the spawned input task, so the
    /// references held by the (never-finishing) input go to zero.
    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // Make an input that never proceeds
        let input = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = input.refs();

        // Configure a RecordBatchReceiverStream to consume the input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // Input should still be present
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // Drop the stream, ensure the refs go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    #[tokio::test]
    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // make an input that will error twice
        let error_stream = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // Get the first result, which should be an error
        let first_batch = stream.next().await.unwrap();
        let first_err = first_batch.unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // There should be no more batches produced (should not get the second error)
        assert!(stream.next().await.is_none());
    }

    /// Consumes all the input's partitions into a
    /// RecordBatchReceiverStream and runs it to completion.
    ///
    /// Panics on the first error item, and panics (via the assertion in the
    /// loop) once `max_batches` batches have been received. At most
    /// `max_batches - 1` batches may therefore arrive before the expected
    /// input panic.
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Configure a RecordBatchReceiverStream to consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        for partition in 0..num_partitions {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            );
        }
        let mut stream = builder.build();

        // Drain the stream until it is complete, panicking on error
        let mut num_batches = 0;
        while let Some(next) = stream.next().await {
            next.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }
}