datafusion_physical_plan/
recursive_query.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the recursive query plan
19
20use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::{
27    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
28    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
29};
30use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
35use datafusion_common::{not_impl_err, DataFusionError, Result};
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
39
40use futures::{ready, Stream, StreamExt};
41
/// Recursive query execution plan.
///
/// This plan has two components: a base part (the static term) and
/// a dynamic part (the recursive term). Execution starts from the
/// static term; afterwards the recursive term is executed again and
/// again for as long as the previous iteration produced at least one row.
///
/// Before each execution of the recursive term, the rows produced by the
/// previous iteration are made available through a "working table" (not a
/// real table: the recursive term reads it through a `WorkTableExec` node
/// that shares this plan's [`WorkTable`]).
///
/// Note that there won't be any limit or checks applied to detect
/// an infinite recursion, so it is up to the planner to ensure that
/// it won't happen.
///
/// NOTE(review): `is_distinct` is stored and displayed, but the stream
/// implemented in this file does not read it; any de-duplication of rows
/// presumably has to happen elsewhere -- confirm.
#[derive(Debug, Clone)]
pub struct RecursiveQueryExec {
    /// Name of the query handler
    name: String,
    /// The working table of the CTE, shared with the `WorkTableExec`
    /// node inside the recursive term
    work_table: Arc<WorkTable>,
    /// The base part (static term)
    static_term: Arc<dyn ExecutionPlan>,
    /// The dynamic part (recursive term)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// Whether the recursive query was declared as distinct
    is_distinct: bool,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
74
75impl RecursiveQueryExec {
76    /// Create a new RecursiveQueryExec
77    pub fn try_new(
78        name: String,
79        static_term: Arc<dyn ExecutionPlan>,
80        recursive_term: Arc<dyn ExecutionPlan>,
81        is_distinct: bool,
82    ) -> Result<Self> {
83        // Each recursive query needs its own work table
84        let work_table = Arc::new(WorkTable::new());
85        // Use the same work table for both the WorkTableExec and the recursive term
86        let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
87        let cache = Self::compute_properties(static_term.schema());
88        Ok(RecursiveQueryExec {
89            name,
90            static_term,
91            recursive_term,
92            is_distinct,
93            work_table,
94            metrics: ExecutionPlanMetricsSet::new(),
95            cache,
96        })
97    }
98
99    /// Ref to name
100    pub fn name(&self) -> &str {
101        &self.name
102    }
103
104    /// Ref to static term
105    pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
106        &self.static_term
107    }
108
109    /// Ref to recursive term
110    pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
111        &self.recursive_term
112    }
113
114    /// is distinct
115    pub fn is_distinct(&self) -> bool {
116        self.is_distinct
117    }
118
119    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
120    fn compute_properties(schema: SchemaRef) -> PlanProperties {
121        let eq_properties = EquivalenceProperties::new(schema);
122
123        PlanProperties::new(
124            eq_properties,
125            Partitioning::UnknownPartitioning(1),
126            EmissionType::Incremental,
127            Boundedness::Bounded,
128        )
129    }
130}
131
132impl ExecutionPlan for RecursiveQueryExec {
133    fn name(&self) -> &'static str {
134        "RecursiveQueryExec"
135    }
136
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140
141    fn properties(&self) -> &PlanProperties {
142        &self.cache
143    }
144
145    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
146        vec![&self.static_term, &self.recursive_term]
147    }
148
149    // TODO: control these hints and see whether we can
150    // infer some from the child plans (static/recursive terms).
151    fn maintains_input_order(&self) -> Vec<bool> {
152        vec![false, false]
153    }
154
155    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
156        vec![false, false]
157    }
158
159    fn required_input_distribution(&self) -> Vec<crate::Distribution> {
160        vec![
161            crate::Distribution::SinglePartition,
162            crate::Distribution::SinglePartition,
163        ]
164    }
165
166    fn with_new_children(
167        self: Arc<Self>,
168        children: Vec<Arc<dyn ExecutionPlan>>,
169    ) -> Result<Arc<dyn ExecutionPlan>> {
170        RecursiveQueryExec::try_new(
171            self.name.clone(),
172            Arc::clone(&children[0]),
173            Arc::clone(&children[1]),
174            self.is_distinct,
175        )
176        .map(|e| Arc::new(e) as _)
177    }
178
179    fn execute(
180        &self,
181        partition: usize,
182        context: Arc<TaskContext>,
183    ) -> Result<SendableRecordBatchStream> {
184        // TODO: we might be able to handle multiple partitions in the future.
185        if partition != 0 {
186            return Err(DataFusionError::Internal(format!(
187                "RecursiveQueryExec got an invalid partition {} (expected 0)",
188                partition
189            )));
190        }
191
192        let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
193        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
194        Ok(Box::pin(RecursiveQueryStream::new(
195            context,
196            Arc::clone(&self.work_table),
197            Arc::clone(&self.recursive_term),
198            static_stream,
199            baseline_metrics,
200        )))
201    }
202
203    fn metrics(&self) -> Option<MetricsSet> {
204        Some(self.metrics.clone_inner())
205    }
206
207    fn statistics(&self) -> Result<Statistics> {
208        Ok(Statistics::new_unknown(&self.schema()))
209    }
210}
211
212impl DisplayAs for RecursiveQueryExec {
213    fn fmt_as(
214        &self,
215        t: DisplayFormatType,
216        f: &mut std::fmt::Formatter,
217    ) -> std::fmt::Result {
218        match t {
219            DisplayFormatType::Default | DisplayFormatType::Verbose => {
220                write!(
221                    f,
222                    "RecursiveQueryExec: name={}, is_distinct={}",
223                    self.name, self.is_distinct
224                )
225            }
226        }
227    }
228}
229
/// The actual logic of the recursive queries happens during the streaming
/// process. A simplified version of the algorithm is the following:
///
/// ```text
/// buffer = []
///
/// while batch := static_stream.next():
///     buffer.push(batch)
///     yield batch
///
/// while buffer.total_rows() > 0:
///     work_table.update(buffer.drain())
///     recursive_stream = reset_plan_states(recursive_term).execute()
///     while batch := recursive_stream.next():
///         buffer.push(batch)
///         yield batch
/// ```
struct RecursiveQueryStream {
    /// The context used for executing the recursive term and for
    /// registering the memory reservation
    task_context: Arc<TaskContext>,
    /// The working table state, representing the self referencing cte table
    work_table: Arc<WorkTable>,
    /// The dynamic part (recursive term) as is (without being executed)
    recursive_term: Arc<dyn ExecutionPlan>,
    /// The static part (static term) as a stream. If the processing of this
    /// part is completed, then it will be None.
    static_stream: Option<SendableRecordBatchStream>,
    /// The dynamic part (recursive term) as a stream. If the processing of this
    /// part has not started yet, or has been completed, then it will be None.
    recursive_stream: Option<SendableRecordBatchStream>,
    /// The schema of the output.
    schema: SchemaRef,
    /// In-memory buffer for storing a copy of the current results. Will be
    /// cleared after each iteration.
    buffer: Vec<RecordBatch>,
    /// Tracks the memory used by the buffer
    reservation: MemoryReservation,
    /// Baseline metrics. Stored but not yet recorded into by this stream
    /// (hence the leading underscore).
    _baseline_metrics: BaselineMetrics,
}
271
272impl RecursiveQueryStream {
273    /// Create a new recursive query stream
274    fn new(
275        task_context: Arc<TaskContext>,
276        work_table: Arc<WorkTable>,
277        recursive_term: Arc<dyn ExecutionPlan>,
278        static_stream: SendableRecordBatchStream,
279        baseline_metrics: BaselineMetrics,
280    ) -> Self {
281        let schema = static_stream.schema();
282        let reservation =
283            MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
284        Self {
285            task_context,
286            work_table,
287            recursive_term,
288            static_stream: Some(static_stream),
289            recursive_stream: None,
290            schema,
291            buffer: vec![],
292            reservation,
293            _baseline_metrics: baseline_metrics,
294        }
295    }
296
297    /// Push a clone of the given batch to the in memory buffer, and then return
298    /// a poll with it.
299    fn push_batch(
300        mut self: std::pin::Pin<&mut Self>,
301        batch: RecordBatch,
302    ) -> Poll<Option<Result<RecordBatch>>> {
303        if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
304            return Poll::Ready(Some(Err(e)));
305        }
306
307        self.buffer.push(batch.clone());
308        Poll::Ready(Some(Ok(batch)))
309    }
310
311    /// Start polling for the next iteration, will be called either after the static term
312    /// is completed or another term is completed. It will follow the algorithm above on
313    /// to check whether the recursion has ended.
314    fn poll_next_iteration(
315        mut self: std::pin::Pin<&mut Self>,
316        cx: &mut Context<'_>,
317    ) -> Poll<Option<Result<RecordBatch>>> {
318        let total_length = self
319            .buffer
320            .iter()
321            .fold(0, |acc, batch| acc + batch.num_rows());
322
323        if total_length == 0 {
324            return Poll::Ready(None);
325        }
326
327        // Update the work table with the current buffer
328        let reserved_batches = ReservedBatches::new(
329            std::mem::take(&mut self.buffer),
330            self.reservation.take(),
331        );
332        self.work_table.update(reserved_batches);
333
334        // We always execute (and re-execute iteratively) the first partition.
335        // Downstream plans should not expect any partitioning.
336        let partition = 0;
337
338        let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
339        self.recursive_stream =
340            Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
341        self.poll_next(cx)
342    }
343}
344
345fn assign_work_table(
346    plan: Arc<dyn ExecutionPlan>,
347    work_table: Arc<WorkTable>,
348) -> Result<Arc<dyn ExecutionPlan>> {
349    let mut work_table_refs = 0;
350    plan.transform_down(|plan| {
351        if let Some(exec) = plan.as_any().downcast_ref::<WorkTableExec>() {
352            if work_table_refs > 0 {
353                not_impl_err!(
354                    "Multiple recursive references to the same CTE are not supported"
355                )
356            } else {
357                work_table_refs += 1;
358                Ok(Transformed::yes(Arc::new(
359                    exec.with_work_table(Arc::clone(&work_table)),
360                )))
361            }
362        } else if plan.as_any().is::<RecursiveQueryExec>() {
363            not_impl_err!("Recursive queries cannot be nested")
364        } else {
365            Ok(Transformed::no(plan))
366        }
367    })
368    .data()
369}
370
371/// Some plans will change their internal states after execution, making them unable to be executed again.
372/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states.
373///
374/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
375/// However, if the data of the left table is derived from the work table, it will become outdated
376/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
377fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
378    plan.transform_up(|plan| {
379        // WorkTableExec's states have already been updated correctly.
380        if plan.as_any().is::<WorkTableExec>() {
381            Ok(Transformed::no(plan))
382        } else {
383            let new_plan = Arc::clone(&plan)
384                .with_new_children(plan.children().into_iter().cloned().collect())?;
385            Ok(Transformed::yes(new_plan))
386        }
387    })
388    .data()
389}
390
391impl Stream for RecursiveQueryStream {
392    type Item = Result<RecordBatch>;
393
394    fn poll_next(
395        mut self: std::pin::Pin<&mut Self>,
396        cx: &mut Context<'_>,
397    ) -> Poll<Option<Self::Item>> {
398        // TODO: we should use this poll to record some metrics!
399        if let Some(static_stream) = &mut self.static_stream {
400            // While the static term's stream is available, we'll be forwarding the batches from it (also
401            // saving them for the initial iteration of the recursive term).
402            let batch_result = ready!(static_stream.poll_next_unpin(cx));
403            match &batch_result {
404                None => {
405                    // Once this is done, we can start running the setup for the recursive term.
406                    self.static_stream = None;
407                    self.poll_next_iteration(cx)
408                }
409                Some(Ok(batch)) => self.push_batch(batch.clone()),
410                _ => Poll::Ready(batch_result),
411            }
412        } else if let Some(recursive_stream) = &mut self.recursive_stream {
413            let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
414            match batch_result {
415                None => {
416                    self.recursive_stream = None;
417                    self.poll_next_iteration(cx)
418                }
419                Some(Ok(batch)) => self.push_batch(batch),
420                _ => Poll::Ready(batch_result),
421            }
422        } else {
423            Poll::Ready(None)
424        }
425    }
426}
427
428impl RecordBatchStream for RecursiveQueryStream {
429    /// Get the schema
430    fn schema(&self) -> SchemaRef {
431        Arc::clone(&self.schema)
432    }
433}
434
// Placeholder test module: it currently contains no tests.
#[cfg(test)]
mod tests {}