datafusion_physical_plan/
recursive_query.rs1use std::any::Any;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
24use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
25use crate::execution_plan::{Boundedness, EmissionType};
26use crate::{
27 metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
28 PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
29};
30use crate::{DisplayAs, DisplayFormatType, ExecutionPlan};
31
32use arrow::datatypes::SchemaRef;
33use arrow::record_batch::RecordBatch;
34use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
35use datafusion_common::{not_impl_err, DataFusionError, Result};
36use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
37use datafusion_execution::TaskContext;
38use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
39
40use futures::{ready, Stream, StreamExt};
41
42#[derive(Debug, Clone)]
58pub struct RecursiveQueryExec {
59 name: String,
61 work_table: Arc<WorkTable>,
63 static_term: Arc<dyn ExecutionPlan>,
65 recursive_term: Arc<dyn ExecutionPlan>,
67 is_distinct: bool,
69 metrics: ExecutionPlanMetricsSet,
71 cache: PlanProperties,
73}
74
75impl RecursiveQueryExec {
76 pub fn try_new(
78 name: String,
79 static_term: Arc<dyn ExecutionPlan>,
80 recursive_term: Arc<dyn ExecutionPlan>,
81 is_distinct: bool,
82 ) -> Result<Self> {
83 let work_table = Arc::new(WorkTable::new());
85 let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
87 let cache = Self::compute_properties(static_term.schema());
88 Ok(RecursiveQueryExec {
89 name,
90 static_term,
91 recursive_term,
92 is_distinct,
93 work_table,
94 metrics: ExecutionPlanMetricsSet::new(),
95 cache,
96 })
97 }
98
99 pub fn name(&self) -> &str {
101 &self.name
102 }
103
104 pub fn static_term(&self) -> &Arc<dyn ExecutionPlan> {
106 &self.static_term
107 }
108
109 pub fn recursive_term(&self) -> &Arc<dyn ExecutionPlan> {
111 &self.recursive_term
112 }
113
114 pub fn is_distinct(&self) -> bool {
116 self.is_distinct
117 }
118
119 fn compute_properties(schema: SchemaRef) -> PlanProperties {
121 let eq_properties = EquivalenceProperties::new(schema);
122
123 PlanProperties::new(
124 eq_properties,
125 Partitioning::UnknownPartitioning(1),
126 EmissionType::Incremental,
127 Boundedness::Bounded,
128 )
129 }
130}
131
132impl ExecutionPlan for RecursiveQueryExec {
133 fn name(&self) -> &'static str {
134 "RecursiveQueryExec"
135 }
136
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140
141 fn properties(&self) -> &PlanProperties {
142 &self.cache
143 }
144
145 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
146 vec![&self.static_term, &self.recursive_term]
147 }
148
149 fn maintains_input_order(&self) -> Vec<bool> {
152 vec![false, false]
153 }
154
155 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
156 vec![false, false]
157 }
158
159 fn required_input_distribution(&self) -> Vec<crate::Distribution> {
160 vec![
161 crate::Distribution::SinglePartition,
162 crate::Distribution::SinglePartition,
163 ]
164 }
165
166 fn with_new_children(
167 self: Arc<Self>,
168 children: Vec<Arc<dyn ExecutionPlan>>,
169 ) -> Result<Arc<dyn ExecutionPlan>> {
170 RecursiveQueryExec::try_new(
171 self.name.clone(),
172 Arc::clone(&children[0]),
173 Arc::clone(&children[1]),
174 self.is_distinct,
175 )
176 .map(|e| Arc::new(e) as _)
177 }
178
179 fn execute(
180 &self,
181 partition: usize,
182 context: Arc<TaskContext>,
183 ) -> Result<SendableRecordBatchStream> {
184 if partition != 0 {
186 return Err(DataFusionError::Internal(format!(
187 "RecursiveQueryExec got an invalid partition {} (expected 0)",
188 partition
189 )));
190 }
191
192 let static_stream = self.static_term.execute(partition, Arc::clone(&context))?;
193 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
194 Ok(Box::pin(RecursiveQueryStream::new(
195 context,
196 Arc::clone(&self.work_table),
197 Arc::clone(&self.recursive_term),
198 static_stream,
199 baseline_metrics,
200 )))
201 }
202
203 fn metrics(&self) -> Option<MetricsSet> {
204 Some(self.metrics.clone_inner())
205 }
206
207 fn statistics(&self) -> Result<Statistics> {
208 Ok(Statistics::new_unknown(&self.schema()))
209 }
210}
211
212impl DisplayAs for RecursiveQueryExec {
213 fn fmt_as(
214 &self,
215 t: DisplayFormatType,
216 f: &mut std::fmt::Formatter,
217 ) -> std::fmt::Result {
218 match t {
219 DisplayFormatType::Default | DisplayFormatType::Verbose => {
220 write!(
221 f,
222 "RecursiveQueryExec: name={}, is_distinct={}",
223 self.name, self.is_distinct
224 )
225 }
226 }
227 }
228}
229
230struct RecursiveQueryStream {
249 task_context: Arc<TaskContext>,
251 work_table: Arc<WorkTable>,
253 recursive_term: Arc<dyn ExecutionPlan>,
255 static_stream: Option<SendableRecordBatchStream>,
258 recursive_stream: Option<SendableRecordBatchStream>,
261 schema: SchemaRef,
263 buffer: Vec<RecordBatch>,
266 reservation: MemoryReservation,
268 _baseline_metrics: BaselineMetrics,
270}
271
272impl RecursiveQueryStream {
273 fn new(
275 task_context: Arc<TaskContext>,
276 work_table: Arc<WorkTable>,
277 recursive_term: Arc<dyn ExecutionPlan>,
278 static_stream: SendableRecordBatchStream,
279 baseline_metrics: BaselineMetrics,
280 ) -> Self {
281 let schema = static_stream.schema();
282 let reservation =
283 MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool());
284 Self {
285 task_context,
286 work_table,
287 recursive_term,
288 static_stream: Some(static_stream),
289 recursive_stream: None,
290 schema,
291 buffer: vec![],
292 reservation,
293 _baseline_metrics: baseline_metrics,
294 }
295 }
296
297 fn push_batch(
300 mut self: std::pin::Pin<&mut Self>,
301 batch: RecordBatch,
302 ) -> Poll<Option<Result<RecordBatch>>> {
303 if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) {
304 return Poll::Ready(Some(Err(e)));
305 }
306
307 self.buffer.push(batch.clone());
308 Poll::Ready(Some(Ok(batch)))
309 }
310
311 fn poll_next_iteration(
315 mut self: std::pin::Pin<&mut Self>,
316 cx: &mut Context<'_>,
317 ) -> Poll<Option<Result<RecordBatch>>> {
318 let total_length = self
319 .buffer
320 .iter()
321 .fold(0, |acc, batch| acc + batch.num_rows());
322
323 if total_length == 0 {
324 return Poll::Ready(None);
325 }
326
327 let reserved_batches = ReservedBatches::new(
329 std::mem::take(&mut self.buffer),
330 self.reservation.take(),
331 );
332 self.work_table.update(reserved_batches);
333
334 let partition = 0;
337
338 let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?;
339 self.recursive_stream =
340 Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?);
341 self.poll_next(cx)
342 }
343}
344
345fn assign_work_table(
346 plan: Arc<dyn ExecutionPlan>,
347 work_table: Arc<WorkTable>,
348) -> Result<Arc<dyn ExecutionPlan>> {
349 let mut work_table_refs = 0;
350 plan.transform_down(|plan| {
351 if let Some(exec) = plan.as_any().downcast_ref::<WorkTableExec>() {
352 if work_table_refs > 0 {
353 not_impl_err!(
354 "Multiple recursive references to the same CTE are not supported"
355 )
356 } else {
357 work_table_refs += 1;
358 Ok(Transformed::yes(Arc::new(
359 exec.with_work_table(Arc::clone(&work_table)),
360 )))
361 }
362 } else if plan.as_any().is::<RecursiveQueryExec>() {
363 not_impl_err!("Recursive queries cannot be nested")
364 } else {
365 Ok(Transformed::no(plan))
366 }
367 })
368 .data()
369}
370
371fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
378 plan.transform_up(|plan| {
379 if plan.as_any().is::<WorkTableExec>() {
381 Ok(Transformed::no(plan))
382 } else {
383 let new_plan = Arc::clone(&plan)
384 .with_new_children(plan.children().into_iter().cloned().collect())?;
385 Ok(Transformed::yes(new_plan))
386 }
387 })
388 .data()
389}
390
391impl Stream for RecursiveQueryStream {
392 type Item = Result<RecordBatch>;
393
394 fn poll_next(
395 mut self: std::pin::Pin<&mut Self>,
396 cx: &mut Context<'_>,
397 ) -> Poll<Option<Self::Item>> {
398 if let Some(static_stream) = &mut self.static_stream {
400 let batch_result = ready!(static_stream.poll_next_unpin(cx));
403 match &batch_result {
404 None => {
405 self.static_stream = None;
407 self.poll_next_iteration(cx)
408 }
409 Some(Ok(batch)) => self.push_batch(batch.clone()),
410 _ => Poll::Ready(batch_result),
411 }
412 } else if let Some(recursive_stream) = &mut self.recursive_stream {
413 let batch_result = ready!(recursive_stream.poll_next_unpin(cx));
414 match batch_result {
415 None => {
416 self.recursive_stream = None;
417 self.poll_next_iteration(cx)
418 }
419 Some(Ok(batch)) => self.push_batch(batch),
420 _ => Poll::Ready(batch_result),
421 }
422 } else {
423 Poll::Ready(None)
424 }
425 }
426}
427
428impl RecordBatchStream for RecursiveQueryStream {
429 fn schema(&self) -> SchemaRef {
431 Arc::clone(&self.schema)
432 }
433}
434
// Placeholder test module; it currently contains no tests.
#[cfg(test)]
mod tests {}