datafusion_expr/logical_plan/
extension.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! This module defines the interface for user-defined (extension) logical plan nodes.
19use crate::{Expr, LogicalPlan};
20use datafusion_common::{DFSchema, DFSchemaRef, Result};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::{any::Any, collections::HashSet, fmt, sync::Arc};
24
25use super::InvariantLevel;
26
27/// This defines the interface for [`LogicalPlan`] nodes that can be
28/// used to extend DataFusion with custom relational operators.
29///
30/// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement*
31/// this trait and avoids having implementing some required boiler plate code.
32pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
33    /// Return a reference to self as Any, to support dynamic downcasting
34    ///
35    /// Typically this will look like:
36    ///
37    /// ```
38    /// # use std::any::Any;
39    /// # struct Dummy { }
40    ///
41    /// # impl Dummy {
42    ///   // canonical boiler plate
43    ///   fn as_any(&self) -> &dyn Any {
44    ///      self
45    ///   }
46    /// # }
47    /// ```
48    fn as_any(&self) -> &dyn Any;
49
50    /// Return the plan's name.
51    fn name(&self) -> &str;
52
53    /// Return the logical plan's inputs.
54    fn inputs(&self) -> Vec<&LogicalPlan>;
55
56    /// Return the output schema of this logical plan node.
57    fn schema(&self) -> &DFSchemaRef;
58
59    /// Perform check of invariants for the extension node.
60    fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;
61
62    /// Returns all expressions in the current logical plan node. This should
63    /// not include expressions of any inputs (aka non-recursively).
64    ///
65    /// These expressions are used for optimizer
66    /// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
67    fn expressions(&self) -> Vec<Expr>;
68
69    /// A list of output columns (e.g. the names of columns in
70    /// self.schema()) for which predicates can not be pushed below
71    /// this node without changing the output.
72    ///
73    /// By default, this returns all columns and thus prevents any
74    /// predicates from being pushed below this node.
75    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
76        // default (safe) is all columns in the schema.
77        get_all_columns_from_schema(self.schema())
78    }
79
80    /// Write a single line, human readable string to `f` for use in explain plan.
81    ///
82    /// For example: `TopK: k=10`
83    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
84
85    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
86    #[allow(clippy::wrong_self_convention)]
87    fn from_template(
88        &self,
89        exprs: &[Expr],
90        inputs: &[LogicalPlan],
91    ) -> Arc<dyn UserDefinedLogicalNode> {
92        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
93            .unwrap()
94    }
95
96    /// Create a new `UserDefinedLogicalNode` with the specified children
97    /// and expressions. This function is used during optimization
98    /// when the plan is being rewritten and a new instance of the
99    /// `UserDefinedLogicalNode` must be created.
100    ///
101    /// Note that exprs and inputs are in the same order as the result
102    /// of self.inputs and self.exprs.
103    ///
104    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
105    fn with_exprs_and_inputs(
106        &self,
107        exprs: Vec<Expr>,
108        inputs: Vec<LogicalPlan>,
109    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
110
111    /// Returns the necessary input columns for this node required to compute
112    /// the columns in the output schema
113    ///
114    /// This is used for projection push-down when DataFusion has determined that
115    /// only a subset of the output columns of this node are needed by its parents.
116    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
117    /// needed.
118    ///
119    /// Return `None`, the default, if this information can not be determined.
120    /// Returns `Some(_)` with the column indices for each child of this node that are
121    /// needed to compute `output_columns`
122    fn necessary_children_exprs(
123        &self,
124        _output_columns: &[usize],
125    ) -> Option<Vec<Vec<usize>>> {
126        None
127    }
128
129    /// Update the hash `state` with this node requirements from
130    /// [`Hash`].
131    ///
132    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
133    /// [`UserDefinedLogicalNode`] directly.
134    ///
135    /// This method is required to support hashing [`LogicalPlan`]s.  To
136    /// implement it, typically the type implementing
137    /// [`UserDefinedLogicalNode`] typically implements [`Hash`] and
138    /// then the following boiler plate is used:
139    ///
140    /// # Example:
141    /// ```
142    /// // User defined node that derives Hash
143    /// #[derive(Hash, Debug, PartialEq, Eq)]
144    /// struct MyNode {
145    ///   val: u64
146    /// }
147    ///
148    /// // impl UserDefinedLogicalNode {
149    /// // ...
150    /// # impl MyNode {
151    ///   // Boiler plate to call the derived Hash impl
152    ///   fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
153    ///     use std::hash::Hash;
154    ///     let mut s = state;
155    ///     self.hash(&mut s);
156    ///   }
157    /// // }
158    /// # }
159    /// ```
160    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`]
161    /// directly because it must remain object safe.
162    fn dyn_hash(&self, state: &mut dyn Hasher);
163
164    /// Compare `other`, respecting requirements from [std::cmp::Eq].
165    ///
166    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
167    /// [`UserDefinedLogicalNode`] directly.
168    ///
169    /// When `other` has an another type than `self`, then the values
170    /// are *not* equal.
171    ///
172    /// This method is required to support Eq on [`LogicalPlan`]s.  To
173    /// implement it, typically the type implementing
174    /// [`UserDefinedLogicalNode`] typically implements [`Eq`] and
175    /// then the following boiler plate is used:
176    ///
177    /// # Example:
178    /// ```
179    /// # use datafusion_expr::UserDefinedLogicalNode;
180    /// // User defined node that derives Eq
181    /// #[derive(Hash, Debug, PartialEq, Eq)]
182    /// struct MyNode {
183    ///   val: u64
184    /// }
185    ///
186    /// // impl UserDefinedLogicalNode {
187    /// // ...
188    /// # impl MyNode {
189    ///   // Boiler plate to call the derived Eq impl
190    ///   fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
191    ///     match other.as_any().downcast_ref::<Self>() {
192    ///       Some(o) => self == o,
193    ///       None => false,
194    ///     }
195    ///   }
196    /// // }
197    /// # }
198    /// ```
199    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`]
200    /// directly because it must remain object safe.
201    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
202    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;
203
204    /// Returns `true` if a limit can be safely pushed down through this
205    /// `UserDefinedLogicalNode` node.
206    ///
207    /// If this method returns `true`, and the query plan contains a limit at
208    /// the output of this node, DataFusion will push the limit to the input
209    /// of this node.
210    fn supports_limit_pushdown(&self) -> bool {
211        false
212    }
213}
214
215impl Hash for dyn UserDefinedLogicalNode {
216    fn hash<H: Hasher>(&self, state: &mut H) {
217        self.dyn_hash(state);
218    }
219}
220
221impl PartialEq for dyn UserDefinedLogicalNode {
222    fn eq(&self, other: &Self) -> bool {
223        self.dyn_eq(other)
224    }
225}
226
227impl PartialOrd for dyn UserDefinedLogicalNode {
228    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
229        self.dyn_ord(other)
230    }
231}
232
233impl Eq for dyn UserDefinedLogicalNode {}
234
235/// This trait facilitates implementation of the [`UserDefinedLogicalNode`].
236///
237/// See the example in
238/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs)
239/// file for an example of how to use this extension API.
240pub trait UserDefinedLogicalNodeCore:
241    fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
242{
243    /// Return the plan's name.
244    fn name(&self) -> &str;
245
246    /// Return the logical plan's inputs.
247    fn inputs(&self) -> Vec<&LogicalPlan>;
248
249    /// Return the output schema of this logical plan node.
250    fn schema(&self) -> &DFSchemaRef;
251
252    /// Perform check of invariants for the extension node.
253    ///
254    /// This is the default implementation for extension nodes.
255    fn check_invariants(
256        &self,
257        _check: InvariantLevel,
258        _plan: &LogicalPlan,
259    ) -> Result<()> {
260        Ok(())
261    }
262
263    /// Returns all expressions in the current logical plan node. This
264    /// should not include expressions of any inputs (aka
265    /// non-recursively). These expressions are used for optimizer
266    /// passes and rewrites.
267    fn expressions(&self) -> Vec<Expr>;
268
269    /// A list of output columns (e.g. the names of columns in
270    /// self.schema()) for which predicates can not be pushed below
271    /// this node without changing the output.
272    ///
273    /// By default, this returns all columns and thus prevents any
274    /// predicates from being pushed below this node.
275    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
276        // default (safe) is all columns in the schema.
277        get_all_columns_from_schema(self.schema())
278    }
279
280    /// Write a single line, human readable string to `f` for use in explain plan.
281    ///
282    /// For example: `TopK: k=10`
283    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
284
285    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
286    #[allow(clippy::wrong_self_convention)]
287    fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
288        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
289            .unwrap()
290    }
291
292    /// Create a new `UserDefinedLogicalNode` with the specified children
293    /// and expressions. This function is used during optimization
294    /// when the plan is being rewritten and a new instance of the
295    /// `UserDefinedLogicalNode` must be created.
296    ///
297    /// Note that exprs and inputs are in the same order as the result
298    /// of self.inputs and self.exprs.
299    ///
300    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
301    fn with_exprs_and_inputs(
302        &self,
303        exprs: Vec<Expr>,
304        inputs: Vec<LogicalPlan>,
305    ) -> Result<Self>;
306
307    /// Returns the necessary input columns for this node required to compute
308    /// the columns in the output schema
309    ///
310    /// This is used for projection push-down when DataFusion has determined that
311    /// only a subset of the output columns of this node are needed by its parents.
312    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
313    /// needed.
314    ///
315    /// Return `None`, the default, if this information can not be determined.
316    /// Returns `Some(_)` with the column indices for each child of this node that are
317    /// needed to compute `output_columns`
318    fn necessary_children_exprs(
319        &self,
320        _output_columns: &[usize],
321    ) -> Option<Vec<Vec<usize>>> {
322        None
323    }
324
325    /// Returns `true` if a limit can be safely pushed down through this
326    /// `UserDefinedLogicalNode` node.
327    ///
328    /// If this method returns `true`, and the query plan contains a limit at
329    /// the output of this node, DataFusion will push the limit to the input
330    /// of this node.
331    fn supports_limit_pushdown(&self) -> bool {
332        false // Disallow limit push-down by default
333    }
334}
335
336/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
337/// to avoid boiler plate for implementing `as_any`, `Hash` and `PartialEq`
338impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
339    fn as_any(&self) -> &dyn Any {
340        self
341    }
342
343    fn name(&self) -> &str {
344        self.name()
345    }
346
347    fn inputs(&self) -> Vec<&LogicalPlan> {
348        self.inputs()
349    }
350
351    fn schema(&self) -> &DFSchemaRef {
352        self.schema()
353    }
354
355    fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
356        self.check_invariants(check, plan)
357    }
358
359    fn expressions(&self) -> Vec<Expr> {
360        self.expressions()
361    }
362
363    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
364        self.prevent_predicate_push_down_columns()
365    }
366
367    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
368        self.fmt_for_explain(f)
369    }
370
371    fn with_exprs_and_inputs(
372        &self,
373        exprs: Vec<Expr>,
374        inputs: Vec<LogicalPlan>,
375    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
376        Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
377    }
378
379    fn necessary_children_exprs(
380        &self,
381        output_columns: &[usize],
382    ) -> Option<Vec<Vec<usize>>> {
383        self.necessary_children_exprs(output_columns)
384    }
385
386    fn dyn_hash(&self, state: &mut dyn Hasher) {
387        let mut s = state;
388        self.hash(&mut s);
389    }
390
391    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
392        match other.as_any().downcast_ref::<Self>() {
393            Some(o) => self == o,
394            None => false,
395        }
396    }
397
398    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
399        other
400            .as_any()
401            .downcast_ref::<Self>()
402            .and_then(|other| self.partial_cmp(other))
403    }
404
405    fn supports_limit_pushdown(&self) -> bool {
406        self.supports_limit_pushdown()
407    }
408}
409
410fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
411    schema.fields().iter().map(|f| f.name().clone()).collect()
412}