// datafusion_expr/logical_plan/extension.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This module defines the interface for logical nodes
19use crate::{Expr, LogicalPlan};
20use datafusion_common::{DFSchema, DFSchemaRef, Result};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::{any::Any, collections::HashSet, fmt, sync::Arc};
24
25use super::InvariantLevel;
26
/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
/// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement*
/// this trait and avoids having to implement some of the required boilerplate code.
pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
    /// Return a reference to self as Any, to support dynamic downcasting
    ///
    /// Typically this will look like:
    ///
    /// ```
    /// # use std::any::Any;
    /// # struct Dummy { }
    ///
    /// # impl Dummy {
    /// // canonical boiler plate
    /// fn as_any(&self) -> &dyn Any {
    ///     self
    /// }
    /// # }
    /// ```
    fn as_any(&self) -> &dyn Any;

    /// Return the plan's name.
    fn name(&self) -> &str;

    /// Return the logical plan's inputs.
    fn inputs(&self) -> Vec<&LogicalPlan>;

    /// Return the output schema of this logical plan node.
    fn schema(&self) -> &DFSchemaRef;

    /// Perform check of invariants for the extension node.
    fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;

    /// Returns all expressions in the current logical plan node. This should
    /// not include expressions of any inputs (aka non-recursively).
    ///
    /// These expressions are used for optimizer
    /// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
    fn expressions(&self) -> Vec<Expr>;

    /// A list of output columns (e.g. the names of columns in
    /// self.schema()) for which predicates can not be pushed below
    /// this node without changing the output.
    ///
    /// By default, this returns all columns and thus prevents any
    /// predicates from being pushed below this node.
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
        // default (safe) is all columns in the schema.
        get_all_columns_from_schema(self.schema())
    }

    /// Write a single line, human readable string to `f` for use in explain plan.
    ///
    /// For example: `TopK: k=10`
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;

    /// Deprecated alias for [`Self::with_exprs_and_inputs`].
    ///
    /// Note: panics if `with_exprs_and_inputs` returns an error.
    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
    #[allow(clippy::wrong_self_convention)]
    fn from_template(
        &self,
        exprs: &[Expr],
        inputs: &[LogicalPlan],
    ) -> Arc<dyn UserDefinedLogicalNode> {
        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
            .unwrap()
    }

    /// Create a new `UserDefinedLogicalNode` with the specified children
    /// and expressions. This function is used during optimization
    /// when the plan is being rewritten and a new instance of the
    /// `UserDefinedLogicalNode` must be created.
    ///
    /// Note that exprs and inputs are in the same order as the result
    /// of self.inputs and self.exprs.
    ///
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;

    /// Returns the necessary input columns for this node required to compute
    /// the columns in the output schema
    ///
    /// This is used for projection push-down when DataFusion has determined that
    /// only a subset of the output columns of this node are needed by its parents.
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
    /// needed.
    ///
    /// Return `None`, the default, if this information can not be determined.
    /// Returns `Some(_)` with the column indices for each child of this node that are
    /// needed to compute `output_columns`
    fn necessary_children_exprs(
        &self,
        _output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        None
    }

    /// Update the hash `state` with this node requirements from
    /// [`Hash`].
    ///
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
    /// [`UserDefinedLogicalNode`] directly.
    ///
    /// This method is required to support hashing [`LogicalPlan`]s. To
    /// implement it, the type implementing
    /// [`UserDefinedLogicalNode`] typically implements [`Hash`] and
    /// then the following boiler plate is used:
    ///
    /// # Example:
    /// ```
    /// // User defined node that derives Hash
    /// #[derive(Hash, Debug, PartialEq, Eq)]
    /// struct MyNode {
    ///     val: u64
    /// }
    ///
    /// // impl UserDefinedLogicalNode {
    /// // ...
    /// # impl MyNode {
    /// // Boiler plate to call the derived Hash impl
    /// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
    ///     use std::hash::Hash;
    ///     let mut s = state;
    ///     self.hash(&mut s);
    /// }
    /// // }
    /// # }
    /// ```
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`]
    /// directly because it must remain object safe.
    fn dyn_hash(&self, state: &mut dyn Hasher);

    /// Compare `other`, respecting requirements from [std::cmp::Eq].
    ///
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
    /// [`UserDefinedLogicalNode`] directly.
    ///
    /// When `other` has a different type than `self`, then the values
    /// are *not* equal.
    ///
    /// This method is required to support Eq on [`LogicalPlan`]s. To
    /// implement it, the type implementing
    /// [`UserDefinedLogicalNode`] typically implements [`Eq`] and
    /// then the following boiler plate is used:
    ///
    /// # Example:
    /// ```
    /// # use datafusion_expr::UserDefinedLogicalNode;
    /// // User defined node that derives Eq
    /// #[derive(Hash, Debug, PartialEq, Eq)]
    /// struct MyNode {
    ///     val: u64
    /// }
    ///
    /// // impl UserDefinedLogicalNode {
    /// // ...
    /// # impl MyNode {
    /// // Boiler plate to call the derived Eq impl
    /// fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
    ///     match other.as_any().downcast_ref::<Self>() {
    ///         Some(o) => self == o,
    ///         None => false,
    ///     }
    /// }
    /// // }
    /// # }
    /// ```
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`]
    /// directly because it must remain object safe.
    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;

    /// Compare `other`, respecting requirements from [`PartialOrd`].
    ///
    /// Returns `None` when `other` has a different concrete type than
    /// `self`, or when the two nodes are otherwise incomparable.
    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;

    /// Returns `true` if a limit can be safely pushed down through this
    /// `UserDefinedLogicalNode` node.
    ///
    /// If this method returns `true`, and the query plan contains a limit at
    /// the output of this node, DataFusion will push the limit to the input
    /// of this node.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
214
215impl Hash for dyn UserDefinedLogicalNode {
216 fn hash<H: Hasher>(&self, state: &mut H) {
217 self.dyn_hash(state);
218 }
219}
220
221impl PartialEq for dyn UserDefinedLogicalNode {
222 fn eq(&self, other: &Self) -> bool {
223 self.dyn_eq(other)
224 }
225}
226
227impl PartialOrd for dyn UserDefinedLogicalNode {
228 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
229 self.dyn_ord(other)
230 }
231}
232
// Marker impl: `dyn_eq` compares concrete type and value (see the blanket
// impl below), so trait objects can satisfy the full `Eq` contract.
impl Eq for dyn UserDefinedLogicalNode {}
234
/// This trait facilitates implementation of the [`UserDefinedLogicalNode`].
///
/// See the example in
/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs)
/// file for an example of how to use this extension API.
pub trait UserDefinedLogicalNodeCore:
    fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
{
    /// Return the plan's name.
    fn name(&self) -> &str;

    /// Return the logical plan's inputs.
    fn inputs(&self) -> Vec<&LogicalPlan>;

    /// Return the output schema of this logical plan node.
    fn schema(&self) -> &DFSchemaRef;

    /// Perform check of invariants for the extension node.
    ///
    /// This is the default implementation for extension nodes:
    /// no invariants are checked.
    fn check_invariants(
        &self,
        _check: InvariantLevel,
        _plan: &LogicalPlan,
    ) -> Result<()> {
        Ok(())
    }

    /// Returns all expressions in the current logical plan node. This
    /// should not include expressions of any inputs (aka
    /// non-recursively). These expressions are used for optimizer
    /// passes and rewrites.
    fn expressions(&self) -> Vec<Expr>;

    /// A list of output columns (e.g. the names of columns in
    /// self.schema()) for which predicates can not be pushed below
    /// this node without changing the output.
    ///
    /// By default, this returns all columns and thus prevents any
    /// predicates from being pushed below this node.
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
        // default (safe) is all columns in the schema.
        get_all_columns_from_schema(self.schema())
    }

    /// Write a single line, human readable string to `f` for use in explain plan.
    ///
    /// For example: `TopK: k=10`
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;

    /// Deprecated alias for [`Self::with_exprs_and_inputs`].
    ///
    /// Note: panics if `with_exprs_and_inputs` returns an error.
    #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
    #[allow(clippy::wrong_self_convention)]
    fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
        self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
            .unwrap()
    }

    /// Create a new `UserDefinedLogicalNode` with the specified children
    /// and expressions. This function is used during optimization
    /// when the plan is being rewritten and a new instance of the
    /// `UserDefinedLogicalNode` must be created.
    ///
    /// Note that exprs and inputs are in the same order as the result
    /// of self.inputs and self.exprs.
    ///
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self>;

    /// Returns the necessary input columns for this node required to compute
    /// the columns in the output schema
    ///
    /// This is used for projection push-down when DataFusion has determined that
    /// only a subset of the output columns of this node are needed by its parents.
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
    /// needed.
    ///
    /// Return `None`, the default, if this information can not be determined.
    /// Returns `Some(_)` with the column indices for each child of this node that are
    /// needed to compute `output_columns`
    fn necessary_children_exprs(
        &self,
        _output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        None
    }

    /// Returns `true` if a limit can be safely pushed down through this
    /// `UserDefinedLogicalNode` node.
    ///
    /// If this method returns `true`, and the query plan contains a limit at
    /// the output of this node, DataFusion will push the limit to the input
    /// of this node.
    fn supports_limit_pushdown(&self) -> bool {
        false // Disallow limit push-down by default
    }
}
335
336/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
337/// to avoid boiler plate for implementing `as_any`, `Hash` and `PartialEq`
338impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
339 fn as_any(&self) -> &dyn Any {
340 self
341 }
342
343 fn name(&self) -> &str {
344 self.name()
345 }
346
347 fn inputs(&self) -> Vec<&LogicalPlan> {
348 self.inputs()
349 }
350
351 fn schema(&self) -> &DFSchemaRef {
352 self.schema()
353 }
354
355 fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
356 self.check_invariants(check, plan)
357 }
358
359 fn expressions(&self) -> Vec<Expr> {
360 self.expressions()
361 }
362
363 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
364 self.prevent_predicate_push_down_columns()
365 }
366
367 fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
368 self.fmt_for_explain(f)
369 }
370
371 fn with_exprs_and_inputs(
372 &self,
373 exprs: Vec<Expr>,
374 inputs: Vec<LogicalPlan>,
375 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
376 Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
377 }
378
379 fn necessary_children_exprs(
380 &self,
381 output_columns: &[usize],
382 ) -> Option<Vec<Vec<usize>>> {
383 self.necessary_children_exprs(output_columns)
384 }
385
386 fn dyn_hash(&self, state: &mut dyn Hasher) {
387 let mut s = state;
388 self.hash(&mut s);
389 }
390
391 fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
392 match other.as_any().downcast_ref::<Self>() {
393 Some(o) => self == o,
394 None => false,
395 }
396 }
397
398 fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
399 other
400 .as_any()
401 .downcast_ref::<Self>()
402 .and_then(|other| self.partial_cmp(other))
403 }
404
405 fn supports_limit_pushdown(&self) -> bool {
406 self.supports_limit_pushdown()
407 }
408}
409
410fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
411 schema.fields().iter().map(|f| f.name().clone()).collect()
412}