datafusion_physical_plan/stream.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream wrappers for physical operators
19
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::Context;
23use std::task::Poll;
24
25use super::metrics::BaselineMetrics;
26use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
27use crate::displayable;
28
29use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
30use datafusion_common::{internal_err, Result};
31use datafusion_execution::TaskContext;
32
33use futures::stream::BoxStream;
34use futures::{Future, Stream, StreamExt};
35use log::debug;
36use pin_project_lite::pin_project;
37use tokio::sync::mpsc::{Receiver, Sender};
38use tokio::task::JoinSet;
39
40/// Creates a stream from a collection of producing tasks, routing panics to the stream.
41///
42/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
43///
44/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
45///
46/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
47///
48/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
49///
50/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
51pub(crate) struct ReceiverStreamBuilder<O> {
52 tx: Sender<Result<O>>,
53 rx: Receiver<Result<O>>,
54 join_set: JoinSet<Result<()>>,
55}
56
57impl<O: Send + 'static> ReceiverStreamBuilder<O> {
58 /// Create new channels with the specified buffer size
59 pub fn new(capacity: usize) -> Self {
60 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
61
62 Self {
63 tx,
64 rx,
65 join_set: JoinSet::new(),
66 }
67 }
68
69 /// Get a handle for sending data to the output
70 pub fn tx(&self) -> Sender<Result<O>> {
71 self.tx.clone()
72 }
73
74 /// Spawn task that will be aborted if this builder (or the stream
75 /// built from it) are dropped
76 pub fn spawn<F>(&mut self, task: F)
77 where
78 F: Future<Output = Result<()>>,
79 F: Send + 'static,
80 {
81 self.join_set.spawn(task);
82 }
83
84 /// Spawn a blocking task that will be aborted if this builder (or the stream
85 /// built from it) are dropped.
86 ///
87 /// This is often used to spawn tasks that write to the sender
88 /// retrieved from `Self::tx`.
89 pub fn spawn_blocking<F>(&mut self, f: F)
90 where
91 F: FnOnce() -> Result<()>,
92 F: Send + 'static,
93 {
94 self.join_set.spawn_blocking(f);
95 }
96
97 /// Create a stream of all data written to `tx`
98 pub fn build(self) -> BoxStream<'static, Result<O>> {
99 let Self {
100 tx,
101 rx,
102 mut join_set,
103 } = self;
104
105 // Doesn't need tx
106 drop(tx);
107
108 // future that checks the result of the join set, and propagates panic if seen
109 let check = async move {
110 while let Some(result) = join_set.join_next().await {
111 match result {
112 Ok(task_result) => {
113 match task_result {
114 // Nothing to report
115 Ok(_) => continue,
116 // This means a blocking task error
117 Err(error) => return Some(Err(error)),
118 }
119 }
120 // This means a tokio task error, likely a panic
121 Err(e) => {
122 if e.is_panic() {
123 // resume on the main thread
124 std::panic::resume_unwind(e.into_panic());
125 } else {
126 // This should only occur if the task is
127 // cancelled, which would only occur if
128 // the JoinSet were aborted, which in turn
129 // would imply that the receiver has been
130 // dropped and this code is not running
131 return Some(internal_err!("Non Panic Task error: {e}"));
132 }
133 }
134 }
135 }
136 None
137 };
138
139 let check_stream = futures::stream::once(check)
140 // unwrap Option / only return the error
141 .filter_map(|item| async move { item });
142
143 // Convert the receiver into a stream
144 let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
145 let next_item = rx.recv().await;
146 next_item.map(|next_item| (next_item, rx))
147 });
148
149 // Merge the streams together so whichever is ready first
150 // produces the batch
151 futures::stream::select(rx_stream, check_stream).boxed()
152 }
153}
154
155/// Builder for `RecordBatchReceiverStream` that propagates errors
156/// and panic's correctly.
157///
158/// [`RecordBatchReceiverStreamBuilder`] is used to spawn one or more tasks
159/// that produce [`RecordBatch`]es and send them to a single
160/// `Receiver` which can improve parallelism.
161///
162/// This also handles propagating panic`s and canceling the tasks.
163///
164/// # Example
165///
166/// The following example spawns 2 tasks that will write [`RecordBatch`]es to
167/// the `tx` end of the builder, after building the stream, we can receive
168/// those batches with calling `.next()`
169///
170/// ```
171/// # use std::sync::Arc;
172/// # use datafusion_common::arrow::datatypes::{Schema, Field, DataType};
173/// # use datafusion_common::arrow::array::RecordBatch;
174/// # use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
175/// # use futures::stream::StreamExt;
176/// # use tokio::runtime::Builder;
177/// # let rt = Builder::new_current_thread().build().unwrap();
178/// #
179/// # rt.block_on(async {
180/// let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int8, false)]));
181/// let mut builder = RecordBatchReceiverStreamBuilder::new(Arc::clone(&schema), 10);
182///
183/// // task 1
184/// let tx_1 = builder.tx();
185/// let schema_1 = Arc::clone(&schema);
186/// builder.spawn(async move {
187/// // Your task needs to send batches to the tx
188/// tx_1.send(Ok(RecordBatch::new_empty(schema_1))).await.unwrap();
189///
190/// Ok(())
191/// });
192///
193/// // task 2
194/// let tx_2 = builder.tx();
195/// let schema_2 = Arc::clone(&schema);
196/// builder.spawn(async move {
197/// // Your task needs to send batches to the tx
198/// tx_2.send(Ok(RecordBatch::new_empty(schema_2))).await.unwrap();
199///
200/// Ok(())
201/// });
202///
203/// let mut stream = builder.build();
204/// while let Some(res_batch) = stream.next().await {
205/// // `res_batch` can either from task 1 or 2
206///
207/// // do something with `res_batch`
208/// }
209/// # });
210/// ```
211pub struct RecordBatchReceiverStreamBuilder {
212 schema: SchemaRef,
213 inner: ReceiverStreamBuilder<RecordBatch>,
214}
215
216impl RecordBatchReceiverStreamBuilder {
217 /// Create new channels with the specified buffer size
218 pub fn new(schema: SchemaRef, capacity: usize) -> Self {
219 Self {
220 schema,
221 inner: ReceiverStreamBuilder::new(capacity),
222 }
223 }
224
225 /// Get a handle for sending [`RecordBatch`] to the output
226 pub fn tx(&self) -> Sender<Result<RecordBatch>> {
227 self.inner.tx()
228 }
229
230 /// Spawn task that will be aborted if this builder (or the stream
231 /// built from it) are dropped
232 ///
233 /// This is often used to spawn tasks that write to the sender
234 /// retrieved from [`Self::tx`], for examples, see the document
235 /// of this type.
236 pub fn spawn<F>(&mut self, task: F)
237 where
238 F: Future<Output = Result<()>>,
239 F: Send + 'static,
240 {
241 self.inner.spawn(task)
242 }
243
244 /// Spawn a blocking task that will be aborted if this builder (or the stream
245 /// built from it) are dropped
246 ///
247 /// This is often used to spawn tasks that write to the sender
248 /// retrieved from [`Self::tx`], for examples, see the document
249 /// of this type.
250 pub fn spawn_blocking<F>(&mut self, f: F)
251 where
252 F: FnOnce() -> Result<()>,
253 F: Send + 'static,
254 {
255 self.inner.spawn_blocking(f)
256 }
257
258 /// Runs the `partition` of the `input` ExecutionPlan on the
259 /// tokio thread pool and writes its outputs to this stream
260 ///
261 /// If the input partition produces an error, the error will be
262 /// sent to the output stream and no further results are sent.
263 pub(crate) fn run_input(
264 &mut self,
265 input: Arc<dyn ExecutionPlan>,
266 partition: usize,
267 context: Arc<TaskContext>,
268 ) {
269 let output = self.tx();
270
271 self.inner.spawn(async move {
272 let mut stream = match input.execute(partition, context) {
273 Err(e) => {
274 // If send fails, the plan being torn down, there
275 // is no place to send the error and no reason to continue.
276 output.send(Err(e)).await.ok();
277 debug!(
278 "Stopping execution: error executing input: {}",
279 displayable(input.as_ref()).one_line()
280 );
281 return Ok(());
282 }
283 Ok(stream) => stream,
284 };
285
286 // Transfer batches from inner stream to the output tx
287 // immediately.
288 while let Some(item) = stream.next().await {
289 let is_err = item.is_err();
290
291 // If send fails, plan being torn down, there is no
292 // place to send the error and no reason to continue.
293 if output.send(item).await.is_err() {
294 debug!(
295 "Stopping execution: output is gone, plan cancelling: {}",
296 displayable(input.as_ref()).one_line()
297 );
298 return Ok(());
299 }
300
301 // Stop after the first error is encountered (Don't
302 // drive all streams to completion)
303 if is_err {
304 debug!(
305 "Stopping execution: plan returned error: {}",
306 displayable(input.as_ref()).one_line()
307 );
308 return Ok(());
309 }
310 }
311
312 Ok(())
313 });
314 }
315
316 /// Create a stream of all [`RecordBatch`] written to `tx`
317 pub fn build(self) -> SendableRecordBatchStream {
318 Box::pin(RecordBatchStreamAdapter::new(
319 self.schema,
320 self.inner.build(),
321 ))
322 }
323}
324
325#[doc(hidden)]
326pub struct RecordBatchReceiverStream {}
327
328impl RecordBatchReceiverStream {
329 /// Create a builder with an internal buffer of capacity batches.
330 pub fn builder(
331 schema: SchemaRef,
332 capacity: usize,
333 ) -> RecordBatchReceiverStreamBuilder {
334 RecordBatchReceiverStreamBuilder::new(schema, capacity)
335 }
336}
337
338pin_project! {
339 /// Combines a [`Stream`] with a [`SchemaRef`] implementing
340 /// [`SendableRecordBatchStream`] for the combination
341 ///
342 /// See [`Self::new`] for an example
343 pub struct RecordBatchStreamAdapter<S> {
344 schema: SchemaRef,
345
346 #[pin]
347 stream: S,
348 }
349}
350
351impl<S> RecordBatchStreamAdapter<S> {
352 /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream.
353 ///
354 /// Note to create a [`SendableRecordBatchStream`] you pin the result
355 ///
356 /// # Example
357 /// ```
358 /// # use arrow::array::record_batch;
359 /// # use datafusion_execution::SendableRecordBatchStream;
360 /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
361 /// // Create stream of Result<RecordBatch>
362 /// let batch = record_batch!(
363 /// ("a", Int32, [1, 2, 3]),
364 /// ("b", Float64, [Some(4.0), None, Some(5.0)])
365 /// ).expect("created batch");
366 /// let schema = batch.schema();
367 /// let stream = futures::stream::iter(vec![Ok(batch)]);
368 /// // Convert the stream to a SendableRecordBatchStream
369 /// let adapter = RecordBatchStreamAdapter::new(schema, stream);
370 /// // Now you can use the adapter as a SendableRecordBatchStream
371 /// let batch_stream: SendableRecordBatchStream = Box::pin(adapter);
372 /// // ...
373 /// ```
374 pub fn new(schema: SchemaRef, stream: S) -> Self {
375 Self { schema, stream }
376 }
377}
378
379impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 f.debug_struct("RecordBatchStreamAdapter")
382 .field("schema", &self.schema)
383 .finish()
384 }
385}
386
387impl<S> Stream for RecordBatchStreamAdapter<S>
388where
389 S: Stream<Item = Result<RecordBatch>>,
390{
391 type Item = Result<RecordBatch>;
392
393 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
394 self.project().stream.poll_next(cx)
395 }
396
397 fn size_hint(&self) -> (usize, Option<usize>) {
398 self.stream.size_hint()
399 }
400}
401
402impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
403where
404 S: Stream<Item = Result<RecordBatch>>,
405{
406 fn schema(&self) -> SchemaRef {
407 Arc::clone(&self.schema)
408 }
409}
410
411/// `EmptyRecordBatchStream` can be used to create a [`RecordBatchStream`]
412/// that will produce no results
413pub struct EmptyRecordBatchStream {
414 /// Schema wrapped by Arc
415 schema: SchemaRef,
416}
417
418impl EmptyRecordBatchStream {
419 /// Create an empty RecordBatchStream
420 pub fn new(schema: SchemaRef) -> Self {
421 Self { schema }
422 }
423}
424
425impl RecordBatchStream for EmptyRecordBatchStream {
426 fn schema(&self) -> SchemaRef {
427 Arc::clone(&self.schema)
428 }
429}
430
431impl Stream for EmptyRecordBatchStream {
432 type Item = Result<RecordBatch>;
433
434 fn poll_next(
435 self: Pin<&mut Self>,
436 _cx: &mut Context<'_>,
437 ) -> Poll<Option<Self::Item>> {
438 Poll::Ready(None)
439 }
440}
441
442/// Stream wrapper that records `BaselineMetrics` for a particular
443/// `[SendableRecordBatchStream]` (likely a partition)
444pub(crate) struct ObservedStream {
445 inner: SendableRecordBatchStream,
446 baseline_metrics: BaselineMetrics,
447 fetch: Option<usize>,
448 produced: usize,
449}
450
451impl ObservedStream {
452 pub fn new(
453 inner: SendableRecordBatchStream,
454 baseline_metrics: BaselineMetrics,
455 fetch: Option<usize>,
456 ) -> Self {
457 Self {
458 inner,
459 baseline_metrics,
460 fetch,
461 produced: 0,
462 }
463 }
464
465 fn limit_reached(
466 &mut self,
467 poll: Poll<Option<Result<RecordBatch>>>,
468 ) -> Poll<Option<Result<RecordBatch>>> {
469 let Some(fetch) = self.fetch else { return poll };
470
471 if self.produced >= fetch {
472 return Poll::Ready(None);
473 }
474
475 if let Poll::Ready(Some(Ok(batch))) = &poll {
476 if self.produced + batch.num_rows() > fetch {
477 let batch = batch.slice(0, fetch.saturating_sub(self.produced));
478 self.produced += batch.num_rows();
479 return Poll::Ready(Some(Ok(batch)));
480 };
481 self.produced += batch.num_rows()
482 }
483 poll
484 }
485}
486
487impl RecordBatchStream for ObservedStream {
488 fn schema(&self) -> SchemaRef {
489 self.inner.schema()
490 }
491}
492
493impl Stream for ObservedStream {
494 type Item = Result<RecordBatch>;
495
496 fn poll_next(
497 mut self: Pin<&mut Self>,
498 cx: &mut Context<'_>,
499 ) -> Poll<Option<Self::Item>> {
500 let mut poll = self.inner.poll_next_unpin(cx);
501 if self.fetch.is_some() {
502 poll = self.limit_reached(poll);
503 }
504 self.baseline_metrics.record_poll(poll)
505 }
506}
507
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
    };

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::exec_err;

    /// Schema with a single nullable `Float32` column named "a", shared by
    /// the tests in this module.
    fn schema() -> SchemaRef {
        let field = Field::new("a", DataType::Float32, true);
        Arc::new(Schema::new(vec![field]))
    }

    /// A panic in an input partition must resurface when the stream is consumed.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn record_batch_receiver_stream_propagates_panics() {
        let partitions = 10;
        let input = PanicExec::new(schema(), partitions);
        consume(input, 10).await
    }

    /// The first panic must shut everything down before the remaining
    /// partition gets to run to its own panic.
    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic: 1")]
    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
        // Two partitions; partition 1 is set up to panic first (after 3),
        // partition 0 only after 10.
        let partitions = 2;
        let input = PanicExec::new(schema(), partitions)
            .with_partition_panic(0, 10)
            .with_partition_panic(1, 3);

        // The stream reads every other batch: (0,1,0,1,0,panic), so no more
        // than 5 batches should be seen before the panic.
        let max_batches = 5;
        consume(input, max_batches).await
    }

    /// Dropping the stream must cancel the task that drives the input.
    #[tokio::test]
    async fn record_batch_receiver_stream_drop_cancel() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // An input that never makes progress
        let blocking = BlockingExec::new(Arc::clone(&schema), 1);
        let refs = blocking.refs();

        // Have a RecordBatchReceiverStream consume that input
        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(blocking), 0, Arc::clone(&task_ctx));
        let stream = builder.build();

        // While the stream is alive the input must still be referenced
        assert!(std::sync::Weak::strong_count(&refs) > 0);

        // After dropping the stream the references must go to zero
        drop(stream);
        assert_strong_count_converges_to_zero(refs).await;
    }

    /// Ensure that if an error is received in one stream, the
    /// `RecordBatchReceiverStream` stops early and does not drive
    /// other streams to completion.
    #[tokio::test]
    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = schema();

        // An input that would yield two errors in a row
        let failing = MockExec::new(
            vec![exec_err!("Test1"), exec_err!("Test2")],
            Arc::clone(&schema),
        )
        .with_use_task(false);

        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
        builder.run_input(Arc::new(failing), 0, Arc::clone(&task_ctx));
        let mut stream = builder.build();

        // The first item must be the first error
        let first_err = stream.next().await.unwrap().unwrap_err();
        assert_eq!(first_err.strip_backtrace(), "Execution error: Test1");

        // Nothing else may follow (in particular not the second error)
        let second = stream.next().await;
        assert!(second.is_none());
    }

    /// Feeds every partition of `input` into one
    /// RecordBatchReceiverStream and drains it to completion.
    ///
    /// Panics if an item is an error or if `max_batches` batches are seen.
    async fn consume(input: PanicExec, max_batches: usize) {
        let task_ctx = Arc::new(TaskContext::default());

        let input = Arc::new(input);
        let num_partitions = input.properties().output_partitioning().partition_count();

        // Have a RecordBatchReceiverStream consume all the input partitions
        let mut builder =
            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
        (0..num_partitions).for_each(|partition| {
            builder.run_input(
                Arc::clone(&input) as Arc<dyn ExecutionPlan>,
                partition,
                Arc::clone(&task_ctx),
            )
        });
        let mut stream = builder.build();

        // Drain the stream, panicking on the first error
        let mut num_batches = 0;
        while let Some(item) = stream.next().await {
            item.unwrap();
            num_batches += 1;
            assert!(
                num_batches < max_batches,
                "Got the limit of {num_batches} batches before seeing panic"
            );
        }
    }
}
635}